Edge extraction and Node Deduplication updates (#564)

* update tests

* updated fact extraction

* optimize node deduplication

* linting

* Update graphiti_core/utils/maintenance/edge_operations.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Preston Rasmussen 2025-06-06 12:28:52 -04:00 committed by GitHub
parent e3f1c679f7
commit ebee09b335
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 50 additions and 95 deletions

View File

@@ -63,6 +63,10 @@ class Person(BaseModel):
occupation: str | None = Field(..., description="The person's work occupation")
class IsPresidentOf(BaseModel):
"""Relationship between a person and the entity they are a president of"""
async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
@@ -84,6 +88,8 @@ async def main():
source_description='Podcast Transcript',
group_id=group_id,
entity_types={'Person': Person},
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
previous_episode_uuids=episode_uuids,
)

View File

@@ -137,8 +137,12 @@ def nodes(context: dict[str, Any]) -> list[Message]:
<ENTITIES>
{json.dumps(context['extracted_nodes'], indent=2)}
</ENTITIES>
<EXISTING ENTITIES>
{json.dumps(context['existing_nodes'], indent=2)}
</EXISTING ENTITIES>
For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
@@ -152,9 +156,9 @@ def nodes(context: dict[str, Any]) -> list[Message]:
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
as an integer.
- If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
- If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a
duplicate of.
- If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
- If an entity is not a duplicate of one of the EXISTING ENTITIES, return the -1 as the duplication_idx
""",
),
]

View File

@@ -24,8 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
class Edge(BaseModel):
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
source_entity_name: str = Field(..., description='The name of the source entity of the fact.')
target_entity_name: str = Field(..., description='The name of the target entity of the fact.')
source_entity_id: int = Field(..., description='The id of the source entity of the fact.')
target_entity_id: int = Field(..., description='The id of the target entity of the fact.')
fact: str = Field(..., description='')
valid_at: str | None = Field(
None,
@@ -77,7 +77,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
</CURRENT_MESSAGE>
<ENTITIES>
{context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases
{context['nodes']}
</ENTITIES>
<REFERENCE_TIME>
@@ -94,8 +94,9 @@ Only extract facts that:
- involve two DISTINCT ENTITIES from the ENTITIES list,
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
and can be represented as edges in a knowledge graph.
- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
could be classified into one of the provided fact types
- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
of the FACT TYPES
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.

View File

@@ -92,8 +92,6 @@ async def extract_edges(
extract_edges_max_tokens = 16384
llm_client = clients.llm_client
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
edge_types_context = (
[
{
@@ -109,7 +107,7 @@ async def extract_edges(
# Prepare context for LLM
context = {
'episode_content': episode.content,
'nodes': [node.name for node in nodes],
'nodes': [{'id': idx, 'name': node.name} for idx, node in enumerate(nodes)],
'previous_episodes': [ep.content for ep in previous_episodes],
'reference_time': episode.valid_at,
'edge_types': edge_types_context,
@@ -160,14 +158,16 @@ async def extract_edges(
invalid_at = edge_data.get('invalid_at', None)
valid_at_datetime = None
invalid_at_datetime = None
source_node_uuid = node_uuids_by_name_map.get(edge_data.get('source_entity_name', ''), '')
target_node_uuid = node_uuids_by_name_map.get(edge_data.get('target_entity_name', ''), '')
if source_node_uuid == '' or target_node_uuid == '':
source_node_idx = edge_data.get('source_entity_id', -1)
target_node_idx = edge_data.get('target_entity_id', -1)
if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
logger.warning(
f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_uuid} and target_node_uuid: {target_node_uuid} '
f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
)
continue
source_node_uuid = nodes[source_node_idx].uuid
target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
if valid_at:
try:

View File

@@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.dedupe_nodes import NodeResolutions
from graphiti_core.prompts.extract_nodes import (
ExtractedEntities,
ExtractedEntity,
@@ -241,7 +241,25 @@ async def resolve_extracted_nodes(
]
)
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
existing_nodes_dict: dict[str, EntityNode] = {
node.uuid: node for result in search_results for node in result.nodes
}
existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
existing_nodes_context = (
[
{
**{
'idx': i,
'name': candidate.name,
'entity_types': candidate.labels,
},
**candidate.attributes,
}
for i, candidate in enumerate(existing_nodes)
],
)
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
@@ -255,23 +273,13 @@ async def resolve_extracted_nodes(
next((item for item in node.labels if item != 'Entity'), '')
).__doc__
or 'Default Entity Type',
'duplication_candidates': [
{
**{
'idx': j,
'name': candidate.name,
'entity_types': candidate.labels,
},
**candidate.attributes,
}
for j, candidate in enumerate(existing_nodes_lists[i])
],
}
for i, node in enumerate(extracted_nodes)
]
context = {
'extracted_nodes': extracted_nodes_context,
'existing_nodes': existing_nodes_context,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
@@ -294,8 +302,8 @@ async def resolve_extracted_nodes(
extracted_node = extracted_nodes[resolution_id]
resolved_node = (
existing_nodes_lists[resolution_id][duplicate_idx]
if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
existing_nodes[duplicate_idx]
if 0 <= duplicate_idx < len(existing_nodes)
else extracted_node
)
@@ -309,70 +317,6 @@ async def resolve_extracted_nodes(
return resolved_nodes, uuid_map
async def resolve_extracted_node(
llm_client: LLMClient,
extracted_node: EntityNode,
existing_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_type: BaseModel | None = None,
) -> EntityNode:
start = time()
if len(existing_nodes) == 0:
return extracted_node
# Prepare context for LLM
existing_nodes_context = [
{
**{
'id': i,
'name': node.name,
'entity_types': node.labels,
},
**node.attributes,
}
for i, node in enumerate(existing_nodes)
]
extracted_node_context = {
'name': extracted_node.name,
'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore
}
context = {
'existing_nodes': existing_nodes_context,
'extracted_node': extracted_node_context,
'entity_type_description': entity_type.__doc__
if entity_type is not None
else 'Default Entity Type',
'episode_content': episode.content if episode is not None else '',
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
else [],
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.node(context),
response_model=NodeDuplicate,
model_size=ModelSize.small,
)
duplicate_id: int = llm_response.get('duplicate_node_id', -1)
node = (
existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
)
node.name = llm_response.get('name', '')
end = time()
logger.debug(
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
)
return node
async def extract_attributes_from_nodes(
clients: GraphitiClients,
nodes: list[EntityNode],

View File

@@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.12.0pre4"
version = "0.12.0"
authors = [
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },