mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
Edge extraction and Node Deduplication updates (#564)
* update tests * updated fact extraction * optimize node deduplication * linting * Update graphiti_core/utils/maintenance/edge_operations.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
e3f1c679f7
commit
ebee09b335
@ -63,6 +63,10 @@ class Person(BaseModel):
|
||||
occupation: str | None = Field(..., description="The person's work occupation")
|
||||
|
||||
|
||||
class IsPresidentOf(BaseModel):
|
||||
"""Relationship between a person and the entity they are a president of"""
|
||||
|
||||
|
||||
async def main():
|
||||
setup_logging()
|
||||
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
@ -84,6 +88,8 @@ async def main():
|
||||
source_description='Podcast Transcript',
|
||||
group_id=group_id,
|
||||
entity_types={'Person': Person},
|
||||
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
|
||||
edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
|
||||
previous_episode_uuids=episode_uuids,
|
||||
)
|
||||
|
||||
|
@ -137,8 +137,12 @@ def nodes(context: dict[str, Any]) -> list[Message]:
|
||||
<ENTITIES>
|
||||
{json.dumps(context['extracted_nodes'], indent=2)}
|
||||
</ENTITIES>
|
||||
|
||||
<EXISTING ENTITIES>
|
||||
{json.dumps(context['existing_nodes'], indent=2)}
|
||||
</EXISTING ENTITIES>
|
||||
|
||||
For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
|
||||
For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
|
||||
|
||||
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
|
||||
|
||||
@ -152,9 +156,9 @@ def nodes(context: dict[str, Any]) -> list[Message]:
|
||||
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
|
||||
as an integer.
|
||||
|
||||
- If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
|
||||
- If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a
|
||||
duplicate of.
|
||||
- If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
|
||||
- If an entity is not a duplicate of one of the EXISTING ENTITIES, return the -1 as the duplication_idx
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
@ -24,8 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
|
||||
|
||||
class Edge(BaseModel):
|
||||
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
|
||||
source_entity_name: str = Field(..., description='The name of the source entity of the fact.')
|
||||
target_entity_name: str = Field(..., description='The name of the target entity of the fact.')
|
||||
source_entity_id: int = Field(..., description='The id of the source entity of the fact.')
|
||||
target_entity_id: int = Field(..., description='The id of the target entity of the fact.')
|
||||
fact: str = Field(..., description='')
|
||||
valid_at: str | None = Field(
|
||||
None,
|
||||
@ -77,7 +77,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
||||
</CURRENT_MESSAGE>
|
||||
|
||||
<ENTITIES>
|
||||
{context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases
|
||||
{context['nodes']}
|
||||
</ENTITIES>
|
||||
|
||||
<REFERENCE_TIME>
|
||||
@ -94,8 +94,9 @@ Only extract facts that:
|
||||
- involve two DISTINCT ENTITIES from the ENTITIES list,
|
||||
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
|
||||
and can be represented as edges in a knowledge graph.
|
||||
- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
|
||||
could be classified into one of the provided fact types
|
||||
- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
|
||||
- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
|
||||
of the FACT TYPES
|
||||
|
||||
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
||||
|
||||
|
@ -92,8 +92,6 @@ async def extract_edges(
|
||||
extract_edges_max_tokens = 16384
|
||||
llm_client = clients.llm_client
|
||||
|
||||
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
||||
|
||||
edge_types_context = (
|
||||
[
|
||||
{
|
||||
@ -109,7 +107,7 @@ async def extract_edges(
|
||||
# Prepare context for LLM
|
||||
context = {
|
||||
'episode_content': episode.content,
|
||||
'nodes': [node.name for node in nodes],
|
||||
'nodes': [{'id': idx, 'name': node.name} for idx, node in enumerate(nodes)],
|
||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||
'reference_time': episode.valid_at,
|
||||
'edge_types': edge_types_context,
|
||||
@ -160,14 +158,16 @@ async def extract_edges(
|
||||
invalid_at = edge_data.get('invalid_at', None)
|
||||
valid_at_datetime = None
|
||||
invalid_at_datetime = None
|
||||
source_node_uuid = node_uuids_by_name_map.get(edge_data.get('source_entity_name', ''), '')
|
||||
target_node_uuid = node_uuids_by_name_map.get(edge_data.get('target_entity_name', ''), '')
|
||||
|
||||
if source_node_uuid == '' or target_node_uuid == '':
|
||||
source_node_idx = edge_data.get('source_entity_id', -1)
|
||||
target_node_idx = edge_data.get('target_entity_id', -1)
|
||||
if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
|
||||
logger.warning(
|
||||
f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_uuid} and target_node_uuid: {target_node_uuid} '
|
||||
f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
|
||||
)
|
||||
continue
|
||||
source_node_uuid = nodes[source_node_idx].uuid
|
||||
target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
|
||||
|
||||
if valid_at:
|
||||
try:
|
||||
|
@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.llm_client.config import ModelSize
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
|
||||
from graphiti_core.prompts.dedupe_nodes import NodeResolutions
|
||||
from graphiti_core.prompts.extract_nodes import (
|
||||
ExtractedEntities,
|
||||
ExtractedEntity,
|
||||
@ -241,7 +241,25 @@ async def resolve_extracted_nodes(
|
||||
]
|
||||
)
|
||||
|
||||
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
|
||||
existing_nodes_dict: dict[str, EntityNode] = {
|
||||
node.uuid: node for result in search_results for node in result.nodes
|
||||
}
|
||||
|
||||
existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
|
||||
|
||||
existing_nodes_context = (
|
||||
[
|
||||
{
|
||||
**{
|
||||
'idx': i,
|
||||
'name': candidate.name,
|
||||
'entity_types': candidate.labels,
|
||||
},
|
||||
**candidate.attributes,
|
||||
}
|
||||
for i, candidate in enumerate(existing_nodes)
|
||||
],
|
||||
)
|
||||
|
||||
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
|
||||
|
||||
@ -255,23 +273,13 @@ async def resolve_extracted_nodes(
|
||||
next((item for item in node.labels if item != 'Entity'), '')
|
||||
).__doc__
|
||||
or 'Default Entity Type',
|
||||
'duplication_candidates': [
|
||||
{
|
||||
**{
|
||||
'idx': j,
|
||||
'name': candidate.name,
|
||||
'entity_types': candidate.labels,
|
||||
},
|
||||
**candidate.attributes,
|
||||
}
|
||||
for j, candidate in enumerate(existing_nodes_lists[i])
|
||||
],
|
||||
}
|
||||
for i, node in enumerate(extracted_nodes)
|
||||
]
|
||||
|
||||
context = {
|
||||
'extracted_nodes': extracted_nodes_context,
|
||||
'existing_nodes': existing_nodes_context,
|
||||
'episode_content': episode.content if episode is not None else '',
|
||||
'previous_episodes': [ep.content for ep in previous_episodes]
|
||||
if previous_episodes is not None
|
||||
@ -294,8 +302,8 @@ async def resolve_extracted_nodes(
|
||||
extracted_node = extracted_nodes[resolution_id]
|
||||
|
||||
resolved_node = (
|
||||
existing_nodes_lists[resolution_id][duplicate_idx]
|
||||
if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
|
||||
existing_nodes[duplicate_idx]
|
||||
if 0 <= duplicate_idx < len(existing_nodes)
|
||||
else extracted_node
|
||||
)
|
||||
|
||||
@ -309,70 +317,6 @@ async def resolve_extracted_nodes(
|
||||
return resolved_nodes, uuid_map
|
||||
|
||||
|
||||
async def resolve_extracted_node(
|
||||
llm_client: LLMClient,
|
||||
extracted_node: EntityNode,
|
||||
existing_nodes: list[EntityNode],
|
||||
episode: EpisodicNode | None = None,
|
||||
previous_episodes: list[EpisodicNode] | None = None,
|
||||
entity_type: BaseModel | None = None,
|
||||
) -> EntityNode:
|
||||
start = time()
|
||||
if len(existing_nodes) == 0:
|
||||
return extracted_node
|
||||
|
||||
# Prepare context for LLM
|
||||
existing_nodes_context = [
|
||||
{
|
||||
**{
|
||||
'id': i,
|
||||
'name': node.name,
|
||||
'entity_types': node.labels,
|
||||
},
|
||||
**node.attributes,
|
||||
}
|
||||
for i, node in enumerate(existing_nodes)
|
||||
]
|
||||
|
||||
extracted_node_context = {
|
||||
'name': extracted_node.name,
|
||||
'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore
|
||||
}
|
||||
|
||||
context = {
|
||||
'existing_nodes': existing_nodes_context,
|
||||
'extracted_node': extracted_node_context,
|
||||
'entity_type_description': entity_type.__doc__
|
||||
if entity_type is not None
|
||||
else 'Default Entity Type',
|
||||
'episode_content': episode.content if episode is not None else '',
|
||||
'previous_episodes': [ep.content for ep in previous_episodes]
|
||||
if previous_episodes is not None
|
||||
else [],
|
||||
}
|
||||
|
||||
llm_response = await llm_client.generate_response(
|
||||
prompt_library.dedupe_nodes.node(context),
|
||||
response_model=NodeDuplicate,
|
||||
model_size=ModelSize.small,
|
||||
)
|
||||
|
||||
duplicate_id: int = llm_response.get('duplicate_node_id', -1)
|
||||
|
||||
node = (
|
||||
existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
|
||||
)
|
||||
|
||||
node.name = llm_response.get('name', '')
|
||||
|
||||
end = time()
|
||||
logger.debug(
|
||||
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
async def extract_attributes_from_nodes(
|
||||
clients: GraphitiClients,
|
||||
nodes: list[EntityNode],
|
||||
|
@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.12.0pre4"
|
||||
version = "0.12.0"
|
||||
authors = [
|
||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||
|
Loading…
x
Reference in New Issue
Block a user