mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
Edge extraction and Node Deduplication updates (#564)
* update tests
* updated fact extraction
* optimize node deduplication
* linting
* Update graphiti_core/utils/maintenance/edge_operations.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
e3f1c679f7
commit
ebee09b335
@ -63,6 +63,10 @@ class Person(BaseModel):
|
|||||||
occupation: str | None = Field(..., description="The person's work occupation")
|
occupation: str | None = Field(..., description="The person's work occupation")
|
||||||
|
|
||||||
|
|
||||||
|
class IsPresidentOf(BaseModel):
|
||||||
|
"""Relationship between a person and the entity they are a president of"""
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
setup_logging()
|
setup_logging()
|
||||||
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||||
@ -84,6 +88,8 @@ async def main():
|
|||||||
source_description='Podcast Transcript',
|
source_description='Podcast Transcript',
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
entity_types={'Person': Person},
|
entity_types={'Person': Person},
|
||||||
|
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
|
||||||
|
edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']},
|
||||||
previous_episode_uuids=episode_uuids,
|
previous_episode_uuids=episode_uuids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -137,8 +137,12 @@ def nodes(context: dict[str, Any]) -> list[Message]:
|
|||||||
<ENTITIES>
|
<ENTITIES>
|
||||||
{json.dumps(context['extracted_nodes'], indent=2)}
|
{json.dumps(context['extracted_nodes'], indent=2)}
|
||||||
</ENTITIES>
|
</ENTITIES>
|
||||||
|
|
||||||
|
<EXISTING ENTITIES>
|
||||||
|
{json.dumps(context['existing_nodes'], indent=2)}
|
||||||
|
</EXISTING ENTITIES>
|
||||||
|
|
||||||
For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
|
For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
|
||||||
|
|
||||||
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
|
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
|
||||||
|
|
||||||
@ -152,9 +156,9 @@ def nodes(context: dict[str, Any]) -> list[Message]:
|
|||||||
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
|
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
|
||||||
as an integer.
|
as an integer.
|
||||||
|
|
||||||
- If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
|
- If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a
|
||||||
duplicate of.
|
duplicate of.
|
||||||
- If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
|
- If an entity is not a duplicate of one of the EXISTING ENTITIES, return the -1 as the duplication_idx
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
@ -24,8 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
|
|||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
|
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
|
||||||
source_entity_name: str = Field(..., description='The name of the source entity of the fact.')
|
source_entity_id: int = Field(..., description='The id of the source entity of the fact.')
|
||||||
target_entity_name: str = Field(..., description='The name of the target entity of the fact.')
|
target_entity_id: int = Field(..., description='The id of the target entity of the fact.')
|
||||||
fact: str = Field(..., description='')
|
fact: str = Field(..., description='')
|
||||||
valid_at: str | None = Field(
|
valid_at: str | None = Field(
|
||||||
None,
|
None,
|
||||||
@ -77,7 +77,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
|
|||||||
</CURRENT_MESSAGE>
|
</CURRENT_MESSAGE>
|
||||||
|
|
||||||
<ENTITIES>
|
<ENTITIES>
|
||||||
{context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases
|
{context['nodes']}
|
||||||
</ENTITIES>
|
</ENTITIES>
|
||||||
|
|
||||||
<REFERENCE_TIME>
|
<REFERENCE_TIME>
|
||||||
@ -94,8 +94,9 @@ Only extract facts that:
|
|||||||
- involve two DISTINCT ENTITIES from the ENTITIES list,
|
- involve two DISTINCT ENTITIES from the ENTITIES list,
|
||||||
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
|
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
|
||||||
and can be represented as edges in a knowledge graph.
|
and can be represented as edges in a knowledge graph.
|
||||||
- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
|
- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
|
||||||
could be classified into one of the provided fact types
|
- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
|
||||||
|
of the FACT TYPES
|
||||||
|
|
||||||
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
|
||||||
|
|
||||||
|
@ -92,8 +92,6 @@ async def extract_edges(
|
|||||||
extract_edges_max_tokens = 16384
|
extract_edges_max_tokens = 16384
|
||||||
llm_client = clients.llm_client
|
llm_client = clients.llm_client
|
||||||
|
|
||||||
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
|
||||||
|
|
||||||
edge_types_context = (
|
edge_types_context = (
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@ -109,7 +107,7 @@ async def extract_edges(
|
|||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
context = {
|
context = {
|
||||||
'episode_content': episode.content,
|
'episode_content': episode.content,
|
||||||
'nodes': [node.name for node in nodes],
|
'nodes': [{'id': idx, 'name': node.name} for idx, node in enumerate(nodes)],
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
'reference_time': episode.valid_at,
|
'reference_time': episode.valid_at,
|
||||||
'edge_types': edge_types_context,
|
'edge_types': edge_types_context,
|
||||||
@ -160,14 +158,16 @@ async def extract_edges(
|
|||||||
invalid_at = edge_data.get('invalid_at', None)
|
invalid_at = edge_data.get('invalid_at', None)
|
||||||
valid_at_datetime = None
|
valid_at_datetime = None
|
||||||
invalid_at_datetime = None
|
invalid_at_datetime = None
|
||||||
source_node_uuid = node_uuids_by_name_map.get(edge_data.get('source_entity_name', ''), '')
|
|
||||||
target_node_uuid = node_uuids_by_name_map.get(edge_data.get('target_entity_name', ''), '')
|
|
||||||
|
|
||||||
if source_node_uuid == '' or target_node_uuid == '':
|
source_node_idx = edge_data.get('source_entity_id', -1)
|
||||||
|
target_node_idx = edge_data.get('target_entity_id', -1)
|
||||||
|
if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_uuid} and target_node_uuid: {target_node_uuid} '
|
f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} '
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
source_node_uuid = nodes[source_node_idx].uuid
|
||||||
|
target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid
|
||||||
|
|
||||||
if valid_at:
|
if valid_at:
|
||||||
try:
|
try:
|
||||||
|
@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient
|
|||||||
from graphiti_core.llm_client.config import ModelSize
|
from graphiti_core.llm_client.config import ModelSize
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
|
from graphiti_core.prompts.dedupe_nodes import NodeResolutions
|
||||||
from graphiti_core.prompts.extract_nodes import (
|
from graphiti_core.prompts.extract_nodes import (
|
||||||
ExtractedEntities,
|
ExtractedEntities,
|
||||||
ExtractedEntity,
|
ExtractedEntity,
|
||||||
@ -241,7 +241,25 @@ async def resolve_extracted_nodes(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
|
existing_nodes_dict: dict[str, EntityNode] = {
|
||||||
|
node.uuid: node for result in search_results for node in result.nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
|
||||||
|
|
||||||
|
existing_nodes_context = (
|
||||||
|
[
|
||||||
|
{
|
||||||
|
**{
|
||||||
|
'idx': i,
|
||||||
|
'name': candidate.name,
|
||||||
|
'entity_types': candidate.labels,
|
||||||
|
},
|
||||||
|
**candidate.attributes,
|
||||||
|
}
|
||||||
|
for i, candidate in enumerate(existing_nodes)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
|
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
|
||||||
|
|
||||||
@ -255,23 +273,13 @@ async def resolve_extracted_nodes(
|
|||||||
next((item for item in node.labels if item != 'Entity'), '')
|
next((item for item in node.labels if item != 'Entity'), '')
|
||||||
).__doc__
|
).__doc__
|
||||||
or 'Default Entity Type',
|
or 'Default Entity Type',
|
||||||
'duplication_candidates': [
|
|
||||||
{
|
|
||||||
**{
|
|
||||||
'idx': j,
|
|
||||||
'name': candidate.name,
|
|
||||||
'entity_types': candidate.labels,
|
|
||||||
},
|
|
||||||
**candidate.attributes,
|
|
||||||
}
|
|
||||||
for j, candidate in enumerate(existing_nodes_lists[i])
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
for i, node in enumerate(extracted_nodes)
|
for i, node in enumerate(extracted_nodes)
|
||||||
]
|
]
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
'extracted_nodes': extracted_nodes_context,
|
'extracted_nodes': extracted_nodes_context,
|
||||||
|
'existing_nodes': existing_nodes_context,
|
||||||
'episode_content': episode.content if episode is not None else '',
|
'episode_content': episode.content if episode is not None else '',
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes]
|
'previous_episodes': [ep.content for ep in previous_episodes]
|
||||||
if previous_episodes is not None
|
if previous_episodes is not None
|
||||||
@ -294,8 +302,8 @@ async def resolve_extracted_nodes(
|
|||||||
extracted_node = extracted_nodes[resolution_id]
|
extracted_node = extracted_nodes[resolution_id]
|
||||||
|
|
||||||
resolved_node = (
|
resolved_node = (
|
||||||
existing_nodes_lists[resolution_id][duplicate_idx]
|
existing_nodes[duplicate_idx]
|
||||||
if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
|
if 0 <= duplicate_idx < len(existing_nodes)
|
||||||
else extracted_node
|
else extracted_node
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -309,70 +317,6 @@ async def resolve_extracted_nodes(
|
|||||||
return resolved_nodes, uuid_map
|
return resolved_nodes, uuid_map
|
||||||
|
|
||||||
|
|
||||||
async def resolve_extracted_node(
|
|
||||||
llm_client: LLMClient,
|
|
||||||
extracted_node: EntityNode,
|
|
||||||
existing_nodes: list[EntityNode],
|
|
||||||
episode: EpisodicNode | None = None,
|
|
||||||
previous_episodes: list[EpisodicNode] | None = None,
|
|
||||||
entity_type: BaseModel | None = None,
|
|
||||||
) -> EntityNode:
|
|
||||||
start = time()
|
|
||||||
if len(existing_nodes) == 0:
|
|
||||||
return extracted_node
|
|
||||||
|
|
||||||
# Prepare context for LLM
|
|
||||||
existing_nodes_context = [
|
|
||||||
{
|
|
||||||
**{
|
|
||||||
'id': i,
|
|
||||||
'name': node.name,
|
|
||||||
'entity_types': node.labels,
|
|
||||||
},
|
|
||||||
**node.attributes,
|
|
||||||
}
|
|
||||||
for i, node in enumerate(existing_nodes)
|
|
||||||
]
|
|
||||||
|
|
||||||
extracted_node_context = {
|
|
||||||
'name': extracted_node.name,
|
|
||||||
'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore
|
|
||||||
}
|
|
||||||
|
|
||||||
context = {
|
|
||||||
'existing_nodes': existing_nodes_context,
|
|
||||||
'extracted_node': extracted_node_context,
|
|
||||||
'entity_type_description': entity_type.__doc__
|
|
||||||
if entity_type is not None
|
|
||||||
else 'Default Entity Type',
|
|
||||||
'episode_content': episode.content if episode is not None else '',
|
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes]
|
|
||||||
if previous_episodes is not None
|
|
||||||
else [],
|
|
||||||
}
|
|
||||||
|
|
||||||
llm_response = await llm_client.generate_response(
|
|
||||||
prompt_library.dedupe_nodes.node(context),
|
|
||||||
response_model=NodeDuplicate,
|
|
||||||
model_size=ModelSize.small,
|
|
||||||
)
|
|
||||||
|
|
||||||
duplicate_id: int = llm_response.get('duplicate_node_id', -1)
|
|
||||||
|
|
||||||
node = (
|
|
||||||
existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
|
|
||||||
)
|
|
||||||
|
|
||||||
node.name = llm_response.get('name', '')
|
|
||||||
|
|
||||||
end = time()
|
|
||||||
logger.debug(
|
|
||||||
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
|
|
||||||
)
|
|
||||||
|
|
||||||
return node
|
|
||||||
|
|
||||||
|
|
||||||
async def extract_attributes_from_nodes(
|
async def extract_attributes_from_nodes(
|
||||||
clients: GraphitiClients,
|
clients: GraphitiClients,
|
||||||
nodes: list[EntityNode],
|
nodes: list[EntityNode],
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.12.0pre4"
|
version = "0.12.0"
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
||||||
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
||||||
|
Loading…
x
Reference in New Issue
Block a user