add_episode() refactor (#421)

* temporal updates

* update resolve nodes

* dedupe edge updates

* edge dedupe

* extract attributes

* update dynamic pydantic model

* first pass of extract node attributes

* no errors

* bug fixes

* bug fixes

* prompt updates

* prompt updates

* updates

* updates

* remove unused imports

* update tests based on changes

* remove unused import
Preston Rasmussen 2025-04-30 12:08:52 -04:00 committed by GitHub
parent 4d5408f02a
commit 1193b25fa3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 485 additions and 509 deletions
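At a high level, this refactor switches the dedupe and invalidation prompts to integer ids, folds temporal extraction into the edge extraction prompt, and runs edge resolution and node attribute extraction concurrently. A simplified sketch of the new resolution step in Graphiti.add_episode(), using the names that appear in the Graphiti class diff below:

# Simplified excerpt of the refactored add_episode() resolution step.
# Node resolution yields resolved nodes plus a uuid_map; edge pointers are
# rewritten to the resolved node uuids, then edge resolution and node
# attribute extraction run concurrently via semaphore_gather.
edges = resolve_edge_pointers(extracted_edges, uuid_map)
(resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
    resolve_extracted_edges(self.clients, edges),
    extract_attributes_from_nodes(self.clients, nodes, episode, previous_episodes, entity_types),
)
entity_edges = resolved_edges + invalidated_edges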

View File

@ -20,7 +20,7 @@ Fordham is a well-regarded private university in New York City, founded in 1841
There's a very daunting hall of portraits outside of my office. You know, all of these priests going back to 1841,
0 (1m 41s):
Tet, LO's own father was in fact a priest. But while getting his psychology PhD at Fordham, he met his Wouldbe wife, another graduate student, so he left the priesthood. Tanya was born in New York not long before the family moved to New Orleans, so Fordham is in her genes.
Tet, LO's own father was in fact a priest. But while getting his psychology PhD at Fordham, he met his Wouldbe wife, another graduate student, so he left the priesthood. Tania was born in New York not long before the family moved to New Orleans, so Fordham is in her genes.
1 (2m 0s):
A good way to recruit me is they can tell me you exist because of us.

View File

@ -72,7 +72,11 @@ from graphiti_core.utils.maintenance.graph_data_operations import (
build_indices_and_constraints,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes
from graphiti_core.utils.maintenance.node_operations import (
extract_attributes_from_nodes,
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
@ -370,15 +374,16 @@ class Graphiti:
extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id),
)
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
extracted_edges, uuid_map
)
edges = resolve_edge_pointers(extracted_edges, uuid_map)
resolved_edges, invalidated_edges = await resolve_extracted_edges(
self.clients,
extracted_edges_with_resolved_pointers,
episode,
previous_episodes,
(resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather(
resolve_extracted_edges(
self.clients,
edges,
),
extract_attributes_from_nodes(
self.clients, nodes, episode, previous_episodes, entity_types
),
)
entity_edges = resolved_edges + invalidated_edges

View File

@ -29,7 +29,7 @@ load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 1))
MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
DEFAULT_PAGE_LIMIT = 20
RUNTIME_QUERY: LiteralString = (
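MAX_REFLEXION_ITERATIONS now defaults to 0, and the extraction loops below switch from '<' to '<=', so the out-of-the-box behaviour is a single extraction pass with the reflexion prompt never invoked. A minimal sketch of the loop shape under that default (see node_operations.py and edge_operations.py below for the real loops):

# Sketch: with MAX_REFLEXION_ITERATIONS == 0 the loop body runs exactly once
# and the reflexion branch is never taken.
MAX_REFLEXION_ITERATIONS = 0
entities_missed = True
reflexion_iterations = 0
while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
    # ... extraction call goes here ...
    reflexion_iterations += 1
    if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
        pass  # the reflexion prompt only runs when the limit is raised above 0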

View File

@ -23,10 +23,9 @@ from .models import Message, PromptFunction, PromptVersion
class EdgeDuplicate(BaseModel):
is_duplicate: bool = Field(..., description='true or false')
uuid: str | None = Field(
None,
description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null",
duplicate_fact_id: int = Field(
...,
description='id of the duplicate fact. If no duplicate facts are found, default to -1.',
)
@ -69,9 +68,8 @@ def edge(context: dict[str, Any]) -> list[Message]:
</NEW EDGE>
Task:
1. If the New Edges represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing edge in the response
If the New Edge represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact.
If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return -1.
Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
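The dedupe response is now keyed by position rather than uuid: candidate facts are passed to the prompt with integer ids, and the caller resolves duplicate_fact_id back into the candidate list, treating -1 (or anything out of range) as "no duplicate". Roughly, as consumed in dedupe_extracted_edge (edge_operations.py below):

# Sketch of how duplicate_fact_id is consumed by the caller.
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
edge = (
    related_edges[duplicate_fact_id]
    if 0 <= duplicate_fact_id < len(related_edges)
    else extracted_edge
)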

View File

@ -23,14 +23,9 @@ from .models import Message, PromptFunction, PromptVersion
class NodeDuplicate(BaseModel):
is_duplicate: bool = Field(..., description='true or false')
uuid: str | None = Field(
None,
description="uuid of the existing node like '5d643020624c42fa9de13f97b1b3fa39' or null",
)
name: str = Field(
duplicate_node_id: int = Field(
...,
description="Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)",
description='id of the duplicate node. If no duplicate nodes are found, default to -1.',
)
@ -64,28 +59,20 @@ def node(context: dict[str, Any]) -> list[Message]:
{json.dumps(context['existing_nodes'], indent=2)}
</EXISTING NODES>
Given the above EXISTING NODES and their attributes, MESSAGE, and PREVIOUS MESSAGES. Determine if the NEW NODE extracted from the conversation
Given the above EXISTING NODES and their attributes, MESSAGE, and PREVIOUS MESSAGES, determine if the NEW NODE extracted from the conversation
is a duplicate entity of one of the EXISTING NODES.
<NEW NODE>
{json.dumps(context['extracted_nodes'], indent=2)}
{json.dumps(context['extracted_node'], indent=2)}
</NEW NODE>
Task:
1. If the New Node represents the same entity as any node in Existing Nodes, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing node in the response
3. If is_duplicate is true, return a name for the node that is the most complete full name.
If the NEW NODE is a duplicate of any node in EXISTING NODES, set duplicate_node_id to the
id of the EXISTING NODE that is the duplicate. If the NEW NODE is not a duplicate of any of the EXISTING NODES,
duplicate_node_id should be set to -1.
Guidelines:
1. Use both the name and summary of nodes to determine if the entities are duplicates,
1. Use the name, summary, and attributes of nodes to determine if the entities are duplicates;
duplicate nodes may have different names
Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": "uuid of the existing node like 5d643020624c42fa9de13f97b1b3fa39 or null",
"name": "Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)"
}}
""",
),
]
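Node dedupe follows the same id-based pattern: existing candidates are numbered in the prompt context, and resolve_extracted_node (node_operations.py below) simply indexes into them, with -1 meaning the extracted node is new. Name merging and pair summarization no longer happen here; summaries and attributes are filled in later by extract_attributes_from_nodes.

# Sketch of the matching consumer in resolve_extracted_node.
duplicate_id: int = llm_response.get('duplicate_node_id', -1)
node = (
    existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
)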

View File

@ -23,10 +23,18 @@ from .models import Message, PromptFunction, PromptVersion
class Edge(BaseModel):
relation_type: str = Field(..., description='RELATION_TYPE_IN_CAPS')
source_entity_name: str = Field(..., description='name of the source entity')
target_entity_name: str = Field(..., description='name of the target entity')
fact: str = Field(..., description='extracted factual information')
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
source_entity_name: str = Field(..., description='The name of the source entity of the fact.')
target_entity_name: str = Field(..., description='The name of the target entity of the fact.')
fact: str = Field(..., description='')
valid_at: str | None = Field(
None,
description='The date and time when the relationship described by the edge fact became true or was established. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
)
invalid_at: str | None = Field(
None,
description='The date and time when the relationship described by the edge fact stopped being true or ended. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
)
class ExtractedEdges(BaseModel):
@ -51,32 +59,59 @@ def edge(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are an expert fact extractor that extracts fact triples from text.',
content='You are an expert fact extractor that extracts fact triples from text. '
'1. Extracted fact triples should also be extracted with relevant date information.'
'2. Treat the CURRENT TIME as the time the CURRENT MESSAGE was sent. All temporal information should be extracted relative to this time.',
),
Message(
role='user',
content=f"""
<PREVIOUS MESSAGES>
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>
<ENTITIES>
{context['nodes']}
</ENTITIES>
{context['custom_prompt']}
<PREVIOUS_MESSAGES>
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
</PREVIOUS_MESSAGES>
Given the above MESSAGES and ENTITIES, extract all facts pertaining to the listed ENTITIES from the CURRENT MESSAGE.
Guidelines:
1. Extract facts only between the provided entities.
2. Each fact should represent a clear relationship between two DISTINCT nodes.
3. The relation_type should be a concise, all-caps description of the fact (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
4. Provide a more detailed fact containing all relevant information.
5. Consider temporal aspects of relationships when relevant.
<CURRENT_MESSAGE>
{context['episode_content']}
</CURRENT_MESSAGE>
<ENTITIES>
{context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases
</ENTITIES>
<REFERENCE_TIME>
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
</REFERENCE_TIME>
# TASK
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
Only extract facts that:
- involve two DISTINCT ENTITIES from the ENTITIES list,
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
- and can be represented as edges in a knowledge graph.
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
{context['custom_prompt']}
# EXTRACTION RULES
1. Only emit facts where both the subject and object match IDs in ENTITIES.
2. Each fact must involve two **distinct** entities.
3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT).
4. Do not emit duplicate or semantically redundant facts.
5. The `fact_text` should quote or closely paraphrase the original source sentence(s).
6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week").
7. Do **not** hallucinate or infer temporal bounds from unrelated events.
# DATETIME RULES
- Use ISO 8601 with Z suffix (UTC) (e.g., 2025-04-30T00:00:00Z).
- If the fact is ongoing (present tense), set `valid_at` to REFERENCE_TIME.
- If a change/termination is expressed, set `invalid_at` to the relevant timestamp.
- Leave both fields `null` if no explicit or resolvable time is stated.
- If only a date is mentioned (no time), assume 00:00:00.
- If only a year is mentioned, use January 1st at 00:00:00.
""",
),
]
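The DATETIME RULES above pair with new parsing in extract_edges (edge_operations.py below): the model's ISO 8601 strings are converted to timezone-aware datetimes, with the trailing 'Z' rewritten to an explicit UTC offset before fromisoformat(). A minimal sketch of that parsing, using the ensure_utc helper imported in the diff below:

from datetime import datetime
from graphiti_core.utils.datetime_utils import ensure_utc

valid_at = '2025-04-30T00:00:00Z'  # example value returned by the model
valid_at_datetime = None
if valid_at:
    try:
        # fromisoformat() does not accept a trailing 'Z' on older Pythons,
        # so it is swapped for an explicit UTC offset first.
        valid_at_datetime = ensure_utc(datetime.fromisoformat(valid_at.replace('Z', '+00:00')))
    except ValueError as e:
        print(f'Error parsing valid_at date: {e}. Input: {valid_at}')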

View File

@ -22,8 +22,16 @@ from pydantic import BaseModel, Field
from .models import Message, PromptFunction, PromptVersion
class ExtractedNodes(BaseModel):
extracted_node_names: list[str] = Field(..., description='Name of the extracted entity')
class ExtractedEntity(BaseModel):
name: str = Field(..., description='Name of the extracted entity')
entity_type_id: int = Field(
description='ID of the classified entity type. '
'Must be one of the provided entity_type_id integers.',
)
class ExtractedEntities(BaseModel):
extracted_entities: list[ExtractedEntity] = Field(..., description='List of extracted entities')
class MissedEntities(BaseModel):
@ -50,6 +58,7 @@ class Prompt(Protocol):
extract_text: PromptVersion
reflexion: PromptVersion
classify_nodes: PromptVersion
extract_attributes: PromptVersion
class Versions(TypedDict):
@ -58,31 +67,49 @@ class Versions(TypedDict):
extract_text: PromptFunction
reflexion: PromptFunction
classify_nodes: PromptFunction
extract_attributes: PromptFunction
def extract_message(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational messages. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational messages.
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""
user_prompt = f"""
<PREVIOUS MESSAGES>
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>
<ENTITY TYPES>
{context['entity_types']}
</ENTITY TYPES>
Instructions:
You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
1. **Speaker Extraction**: Always extract the speaker (the part before the colon `:` in each dialogue line) as the first entity node.
- If the speaker is mentioned again in the message, treat both mentions as a **single entity**.
2. **Entity Identification**:
- Extract all significant entities, concepts, or actors that are **explicitly or implicitly** mentioned in the CURRENT MESSAGE.
- **Exclude** entities mentioned only in the PREVIOUS MESSAGES (they are for context only).
3. **Entity Classification**:
- Use the descriptions in ENTITY TYPES to classify each extracted entity.
- Assign the appropriate `entity_type_id` for each one.
4. **Exclusions**:
- Do NOT extract entities representing relationships or actions.
- Do NOT extract dates, times, or other temporal information; these will be handled separately.
5. **Formatting**:
- Be **explicit and unambiguous** in naming entities (e.g., use full names when available).
{context['custom_prompt']}
Given the above conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:
Guidelines:
1. ALWAYS extract the speaker/actor as the first node. The speaker is the part before the colon in each line of dialogue.
2. Extract other significant entities, concepts, or actors mentioned in the CURRENT MESSAGE.
3. DO NOT create nodes for relationships or actions.
4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
5. Be as explicit as possible in your node names, using full names.
6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
7. Extract preferences as their own nodes
"""
return [
Message(role='system', content=sys_prompt),
@ -92,7 +119,7 @@ Guidelines:
def extract_json(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from JSON.
Your primary task is to identify and extract relevant entities from JSON files"""
Your primary task is to extract and classify relevant entities from JSON files"""
user_prompt = f"""
<SOURCE DESCRIPTION>:
@ -101,10 +128,15 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
<JSON>
{context['episode_content']}
</JSON>
<ENTITY TYPES>
{context['entity_types']}
</ENTITY TYPES>
{context['custom_prompt']}
Given the above source description and JSON, extract relevant entity nodes from the provided JSON:
Given the above source description and JSON, extract relevant entities from the provided JSON.
For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
Indicate the classified entity type by providing its entity_type_id.
Guidelines:
1. Always try to extract an entity that the JSON represents. This will often be something like a "name" or "user" field
@ -117,17 +149,23 @@ Guidelines:
def extract_text(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the provided text."""
sys_prompt = """You are an AI assistant that extracts entity nodes from text.
Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""
user_prompt = f"""
<TEXT>
{context['episode_content']}
</TEXT>
<ENTITY TYPES>
{context['entity_types']}
</ENTITY TYPES>
Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
Indicate the classified entity type by providing its entity_type_id.
{context['custom_prompt']}
Given the above text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:
Guidelines:
1. Extract significant entities, concepts, or actors mentioned in the conversation.
2. Avoid creating nodes for relationships or actions.
@ -196,10 +234,43 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
]
def extract_attributes(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that extracts entity properties from the provided text.',
),
Message(
role='user',
content=f"""
<MESSAGES>
{json.dumps(context['previous_episodes'], indent=2)}
{json.dumps(context['episode_content'], indent=2)}
</MESSAGES>
Given the above MESSAGES and the following ENTITY, update any of its attributes based on the information provided
in MESSAGES. Use the provided attribute descriptions to better understand how each attribute should be determined.
Guidelines:
1. Do not hallucinate entity property values if they cannot be found in the current context.
2. Only use the provided MESSAGES and ENTITY to set attribute values.
3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES.
Summaries must be no longer than 200 words.
<ENTITY>
{context['node']}
</ENTITY>
""",
),
]
versions: Versions = {
'extract_message': extract_message,
'extract_json': extract_json,
'extract_text': extract_text,
'reflexion': reflexion,
'classify_nodes': classify_nodes,
'extract_attributes': extract_attributes,
}
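Extraction now returns ExtractedEntities, where each entity carries an integer entity_type_id pointing into the entity type context the caller builds (id 0 is always the default 'Entity' classification). In extract_nodes (node_operations.py below) the id is mapped straight back to a type name and node labels:

# Sketch of the id-to-label mapping performed by the caller.
entity_type_name = entity_types_context[extracted_entity.entity_type_id].get('entity_type_name')
labels: list[str] = list({'Entity', str(entity_type_name)})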

View File

@ -21,14 +21,10 @@ from pydantic import BaseModel, Field
from .models import Message, PromptFunction, PromptVersion
class InvalidatedEdge(BaseModel):
uuid: str = Field(..., description='The UUID of the edge to be invalidated')
fact: str = Field(..., description='Updated fact of the edge')
class InvalidatedEdges(BaseModel):
invalidated_edges: list[InvalidatedEdge] = Field(
..., description='List of edges that should be invalidated'
contradicted_facts: list[int] = Field(
...,
description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
)
@ -78,18 +74,22 @@ def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
content='You are an AI assistant that determines which facts contradict each other.',
),
Message(
role='user',
content=f"""
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge.
Based on the provided EXISTING FACTS and a NEW FACT, determine which existing facts the new fact contradicts.
Return a list containing all ids of the facts that are contradicted by the NEW FACT.
If there are no contradicted facts, return an empty list.
Existing Edges:
<EXISTING FACTS>
{context['existing_edges']}
</EXISTING FACTS>
New Edge:
<NEW FACT>
{context['new_edge']}
</NEW FACT>
""",
),
]
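Contradiction detection follows the same id-based pattern: existing facts are numbered in the prompt context and the response is a list of integer ids, which get_edge_contradictions (temporal_operations.py below) maps directly back to edges:

# Sketch of the id-to-edge mapping performed by the caller.
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
contradicted_edges: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]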

View File

@ -341,10 +341,10 @@ async def node_fulltext_search(
query = (
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
@ -672,21 +672,36 @@ async def get_relevant_nodes(
"""
+ filter_query
+ """
WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
WHERE score > $min_score
WITH node, n, score
ORDER BY score DESC
RETURN node.uuid AS search_node_uuid,
collect({
uuid: n.uuid,
name: n.name,
name_embedding: n.name_embedding,
group_id: n.group_id,
created_at: n.created_at,
summary: n.summary,
labels: labels(n),
attributes: properties(n)
})[..$limit] AS matches
WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
WHERE score > $min_score
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
CALL db.index.fulltext.queryNodes("node_name_and_summary", 'group_id:"' + $group_id + '" AND ' + node.name, {limit: $limit})
YIELD node AS m
WHERE m.group_id = $group_id
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
WITH node,
top_vector_nodes,
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
UNWIND combined_nodes AS combined_node
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
RETURN
node.uuid AS search_node_uuid,
[x IN deduped_nodes | {
uuid: x.uuid,
name: x.name,
name_embedding: x.name_embedding,
group_id: x.group_id,
created_at: x.created_at,
summary: x.summary,
labels: labels(x),
attributes: properties(x)
}] AS matches
"""
)

View File

@ -33,9 +33,8 @@ from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
get_edge_contradictions,
)
@ -100,12 +99,13 @@ async def extract_edges(
'episode_content': episode.content,
'nodes': [node.name for node in nodes],
'previous_episodes': [ep.content for ep in previous_episodes],
'reference_time': episode.valid_at,
'custom_prompt': '',
}
facts_missed = True
reflexion_iterations = 0
while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
while facts_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
llm_response = await llm_client.generate_response(
prompt_library.extract_edges.edge(context),
response_model=ExtractedEdges,
@ -118,7 +118,9 @@ async def extract_edges(
reflexion_iterations += 1
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
reflexion_response = await llm_client.generate_response(
prompt_library.extract_edges.reflexion(context), response_model=MissingFacts
prompt_library.extract_edges.reflexion(context),
response_model=MissingFacts,
max_tokens=extract_edges_max_tokens,
)
missing_facts = reflexion_response.get('missing_facts', [])
@ -134,9 +136,33 @@ async def extract_edges(
end = time()
logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
if len(edges_data) == 0:
return []
# Convert the extracted data into EntityEdge objects
edges = []
for edge_data in edges_data:
# Validate Edge Date information
valid_at = edge_data.get('valid_at', None)
invalid_at = edge_data.get('invalid_at', None)
valid_at_datetime = None
invalid_at_datetime = None
if valid_at:
try:
valid_at_datetime = ensure_utc(
datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
)
except ValueError as e:
logger.warning(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}')
if invalid_at:
try:
invalid_at_datetime = ensure_utc(
datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
)
except ValueError as e:
logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
edge = EntityEdge(
source_node_uuid=node_uuids_by_name_map.get(
edge_data.get('source_entity_name', ''), ''
@ -149,8 +175,8 @@ async def extract_edges(
fact=edge_data.get('fact', ''),
episodes=[episode.uuid],
created_at=utc_now(),
valid_at=None,
invalid_at=None,
valid_at=valid_at_datetime,
invalid_at=invalid_at_datetime,
)
edges.append(edge)
logger.debug(
@ -211,25 +237,22 @@ async def dedupe_extracted_edges(
async def resolve_extracted_edges(
clients: GraphitiClients,
extracted_edges: list[EntityEdge],
current_episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
driver = clients.driver
llm_client = clients.llm_client
related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges(
driver, extracted_edges, SearchFilters()
search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather(
get_relevant_edges(driver, extracted_edges, SearchFilters()),
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()),
)
related_edges_lists, edge_invalidation_candidates = search_results
logger.debug(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
)
edge_invalidation_candidates: list[list[EntityEdge]] = await get_edge_invalidation_candidates(
driver, extracted_edges, SearchFilters()
)
# resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
# resolve edges with related edges in the graph and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
await semaphore_gather(
*[
@ -238,11 +261,9 @@ async def resolve_extracted_edges(
extracted_edge,
related_edges,
existing_edges,
current_episode,
previous_episodes,
)
for extracted_edge, related_edges, existing_edges in zip(
extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=False
extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True
)
]
)
@ -265,6 +286,9 @@ async def resolve_extracted_edges(
def resolve_edge_contradictions(
resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
) -> list[EntityEdge]:
if len(invalidation_candidates) == 0:
return []
# Determine which contradictory edges need to be expired
invalidated_edges: list[EntityEdge] = []
for edge in invalidation_candidates:
@ -297,21 +321,15 @@ async def resolve_extracted_edge(
extracted_edge: EntityEdge,
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
current_episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]:
resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather(
resolved_edge, invalidation_candidates = await semaphore_gather(
dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
)
now = utc_now()
resolved_edge.valid_at = valid_at if valid_at else resolved_edge.valid_at
resolved_edge.invalid_at = invalid_at if invalid_at else resolved_edge.invalid_at
if invalid_at and not resolved_edge.expired_at:
if resolved_edge.invalid_at and not resolved_edge.expired_at:
resolved_edge.expired_at = now
# Determine if the new_edge needs to be expired
@ -339,16 +357,17 @@ async def resolve_extracted_edge(
async def dedupe_extracted_edge(
llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
) -> EntityEdge:
if len(related_edges) == 0:
return extracted_edge
start = time()
# Prepare context for LLM
related_edges_context = [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges
{'id': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)
]
extracted_edge_context = {
'uuid': extracted_edge.uuid,
'name': extracted_edge.name,
'fact': extracted_edge.fact,
}
@ -361,15 +380,13 @@ async def dedupe_extracted_edge(
prompt_library.dedupe_edges.edge(context), response_model=EdgeDuplicate
)
is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)
duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1)
edge = extracted_edge
if is_duplicate:
for existing_edge in related_edges:
if existing_edge.uuid != uuid:
continue
edge = existing_edge
edge = (
related_edges[duplicate_fact_id]
if 0 <= duplicate_fact_id < len(related_edges)
else extracted_edge
)
end = time()
logger.debug(
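With temporal extraction folded into the extraction prompt, resolve_extracted_edge no longer takes the episode arguments or calls extract_edge_dates: dates arrive on the extracted edge itself, and dedupe runs concurrently with contradiction detection. Simplified shape of the flow shown above:

# Simplified shape of the refactored resolve_extracted_edge.
resolved_edge, invalidation_candidates = await semaphore_gather(
    dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
    get_edge_contradictions(llm_client, extracted_edge, existing_edges),
)
if resolved_edge.invalid_at and not resolved_edge.expired_at:
    resolved_edge.expired_at = utc_now()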

View File

@ -20,7 +20,7 @@ from time import time
from typing import Any
import pydantic
from pydantic import BaseModel
from pydantic import BaseModel, Field
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
@ -28,8 +28,11 @@ from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
from graphiti_core.prompts.summarize_nodes import Summary
from graphiti_core.prompts.extract_nodes import (
ExtractedEntities,
ExtractedEntity,
MissedEntities,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_nodes
from graphiti_core.utils.datetime_utils import utc_now
@ -37,66 +40,6 @@ from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)
async def extract_message_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
custom_prompt='',
) -> list[str]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'previous_episodes': [ep.content for ep in previous_episodes],
'custom_prompt': custom_prompt,
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_message(context), response_model=ExtractedNodes
)
extracted_node_names = llm_response.get('extracted_node_names', [])
return extracted_node_names
async def extract_text_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
custom_prompt='',
) -> list[str]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'previous_episodes': [ep.content for ep in previous_episodes],
'custom_prompt': custom_prompt,
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_text(context), ExtractedNodes
)
extracted_node_names = llm_response.get('extracted_node_names', [])
return extracted_node_names
async def extract_json_nodes(
llm_client: LLMClient, episode: EpisodicNode, custom_prompt=''
) -> list[str]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'source_description': episode.source_description,
'custom_prompt': custom_prompt,
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_json(context), ExtractedNodes
)
extracted_node_names = llm_response.get('extracted_node_names', [])
return extracted_node_names
async def extract_nodes_reflexion(
llm_client: LLMClient,
episode: EpisodicNode,
@ -127,82 +70,88 @@ async def extract_nodes(
start = time()
llm_client = clients.llm_client
embedder = clients.embedder
extracted_node_names: list[str] = []
llm_response = {}
custom_prompt = ''
entities_missed = True
reflexion_iterations = 0
while entities_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
entity_types_context = [
{
'entity_type_id': 0,
'entity_type_name': 'Entity',
'entity_type_description': 'Default entity classification. Use this entity type if the entity is not one of the other listed types.',
}
]
entity_types_context += (
[
{
'entity_type_id': i + 1,
'entity_type_name': type_name,
'entity_type_description': type_model.__doc__,
}
for i, (type_name, type_model) in enumerate(entity_types.items())
]
if entity_types is not None
else []
)
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'previous_episodes': [ep.content for ep in previous_episodes],
'custom_prompt': custom_prompt,
'entity_types': entity_types_context,
}
while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
if episode.source == EpisodeType.message:
extracted_node_names = await extract_message_nodes(
llm_client, episode, previous_episodes, custom_prompt
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_message(context),
response_model=ExtractedEntities,
)
elif episode.source == EpisodeType.text:
extracted_node_names = await extract_text_nodes(
llm_client, episode, previous_episodes, custom_prompt
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_text(context), response_model=ExtractedEntities
)
elif episode.source == EpisodeType.json:
extracted_node_names = await extract_json_nodes(llm_client, episode, custom_prompt)
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities
)
extracted_entities: list[ExtractedEntity] = [
ExtractedEntity(**entity_types_context)
for entity_types_context in llm_response.get('extracted_entities', [])
]
reflexion_iterations += 1
if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
missing_entities = await extract_nodes_reflexion(
llm_client, episode, previous_episodes, extracted_node_names
llm_client,
episode,
previous_episodes,
[entity.name for entity in extracted_entities],
)
entities_missed = len(missing_entities) != 0
custom_prompt = 'The following entities were missed in a previous extraction: '
custom_prompt = 'Make sure that the following entities are extracted: '
for entity in missing_entities:
custom_prompt += f'\n{entity},'
reflexion_iterations += 1
node_classification_context = {
'episode_content': episode.content,
'previous_episodes': [ep.content for ep in previous_episodes],
'extracted_entities': extracted_node_names,
'entity_types': {
type_name: values.model_json_schema().get('description')
for type_name, values in entity_types.items()
}
if entity_types is not None
else {},
}
node_classifications: dict[str, str | None] = {}
if entity_types is not None:
try:
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.classify_nodes(node_classification_context),
response_model=EntityClassification,
)
entity_classifications = llm_response.get('entity_classifications', [])
node_classifications.update(
{
entity_classification.get('name'): entity_classification.get('entity_type')
for entity_classification in entity_classifications
}
)
# catch classification errors and continue if we can't classify
except Exception as e:
logger.exception(e)
end = time()
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
logger.debug(f'Extracted new nodes: {extracted_entities} in {(end - start) * 1000} ms')
# Convert the extracted data into EntityNode objects
extracted_nodes = []
for name in extracted_node_names:
entity_type = node_classifications.get(name)
if entity_types is not None and entity_type not in entity_types:
entity_type = None
labels = (
['Entity']
if entity_type is None or entity_type == 'None' or entity_type == 'null'
else ['Entity', entity_type]
for extracted_entity in extracted_entities:
entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
'entity_type_name'
)
labels: list[str] = list({'Entity', str(entity_type_name)})
new_node = EntityNode(
name=name,
name=extracted_entity.name,
group_id=episode.group_id,
labels=labels,
summary='',
@ -282,29 +231,29 @@ async def resolve_extracted_nodes(
driver, extracted_nodes, SearchFilters()
)
uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = []
results: list[tuple[EntityNode, dict[str, str]]] = list(
await semaphore_gather(
*[
resolve_extracted_node(
llm_client,
extracted_node,
existing_nodes,
episode,
previous_episodes,
entity_types,
resolved_nodes: list[EntityNode] = await semaphore_gather(
*[
resolve_extracted_node(
llm_client,
extracted_node,
existing_nodes,
episode,
previous_episodes,
entity_types.get(
next((item for item in extracted_node.labels if item != 'Entity'), '')
)
for extracted_node, existing_nodes in zip(
extracted_nodes, existing_nodes_lists, strict=False
)
]
)
if entity_types is not None
else None,
)
for extracted_node, existing_nodes in zip(
extracted_nodes, existing_nodes_lists, strict=True
)
]
)
for result in results:
uuid_map.update(result[1])
resolved_nodes.append(result[0])
uuid_map: dict[str, str] = {}
for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True):
uuid_map[extracted_node.uuid] = resolved_node.uuid
logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}')
@ -317,124 +266,151 @@ async def resolve_extracted_node(
existing_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityNode, dict[str, str]]:
entity_type: BaseModel | None = None,
) -> EntityNode:
start = time()
if len(existing_nodes) == 0:
return extracted_node
# Prepare context for LLM
existing_nodes_context = [
{**{'uuid': node.uuid, 'name': node.name, 'summary': node.summary}, **node.attributes}
for node in existing_nodes
{
**{
'id': i,
'name': node.name,
'entity_types': node.labels,
'summary': node.summary,
},
**node.attributes,
}
for i, node in enumerate(existing_nodes)
]
extracted_node_context = {
'uuid': extracted_node.uuid,
'name': extracted_node.name,
'summary': extracted_node.summary,
'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore
'entity_type_description': entity_type.__doc__
if entity_type is not None
else 'Default Entity Type',
}
context = {
'existing_nodes': existing_nodes_context,
'extracted_nodes': extracted_node_context,
'extracted_node': extracted_node_context,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
else [],
}
summary_context: dict[str, Any] = {
'node_name': extracted_node.name,
'node_summary': extracted_node.summary,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
else [],
}
attributes: list[dict[str, str]] = []
entity_type_classes: tuple[BaseModel, ...] = tuple()
if entity_types is not None: # type: ignore
entity_type_classes = entity_type_classes + tuple(
filter(
lambda x: x is not None, # type: ignore
[entity_types.get(entity_type) for entity_type in extracted_node.labels], # type: ignore
)
)
for entity_type in entity_type_classes:
for field_name, field_info in entity_type.model_fields.items():
attributes.append(
{
'attribute_name': field_name,
'attribute_description': field_info.description or '',
}
)
summary_context['attributes'] = attributes
entity_attributes_model = pydantic.create_model( # type: ignore
'EntityAttributes',
__base__=entity_type_classes + (Summary,), # type: ignore
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
)
llm_response, node_attributes_response = await semaphore_gather(
llm_client.generate_response(
prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
),
llm_client.generate_response(
prompt_library.summarize_nodes.summarize_context(summary_context),
response_model=entity_attributes_model,
),
duplicate_id: int = llm_response.get('duplicate_node_id', -1)
node = (
existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
)
extracted_node.summary = node_attributes_response.get('summary', '')
node_attributes = {
key: value if (value != 'None' or key == 'summary') else None
for key, value in node_attributes_response.items()
}
with suppress(KeyError):
del node_attributes['summary']
extracted_node.attributes.update(node_attributes)
is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)
name = llm_response.get('name', '')
node = extracted_node
uuid_map: dict[str, str] = {}
if is_duplicate:
for existing_node in existing_nodes:
if existing_node.uuid != uuid:
continue
summary_response = await llm_client.generate_response(
prompt_library.summarize_nodes.summarize_pair(
{'node_summaries': [extracted_node.summary, existing_node.summary]}
),
response_model=Summary,
)
node = existing_node
node.name = name
node.summary = summary_response.get('summary', '')
new_attributes = extracted_node.attributes
existing_attributes = existing_node.attributes
for attribute_name, attribute_value in existing_attributes.items():
if new_attributes.get(attribute_name) is None:
new_attributes[attribute_name] = attribute_value
node.attributes = new_attributes
node.labels = list(set(existing_node.labels + extracted_node.labels))
uuid_map[extracted_node.uuid] = existing_node.uuid
end = time()
logger.debug(
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
)
return node, uuid_map
return node
async def extract_attributes_from_nodes(
clients: GraphitiClients,
nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> list[EntityNode]:
llm_client = clients.llm_client
embedder = clients.embedder
updated_nodes: list[EntityNode] = await semaphore_gather(
*[
extract_attributes_from_node(
llm_client,
node,
episode,
previous_episodes,
entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
if entity_types is not None
else None,
)
for node in nodes
]
)
await create_entity_node_embeddings(embedder, updated_nodes)
return updated_nodes
async def extract_attributes_from_node(
llm_client: LLMClient,
node: EntityNode,
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_type: BaseModel | None = None,
) -> EntityNode:
node_context: dict[str, Any] = {
'name': node.name,
'summary': node.summary,
'entity_types': node.labels,
'attributes': node.attributes,
}
attributes_definitions: dict[str, Any] = {
'summary': (
str,
Field(
description='Summary containing the important information about the entity. Under 200 words',
),
),
'name': (
str,
Field(description='Name of the ENTITY'),
),
}
if entity_type is not None:
for field_name, field_info in entity_type.model_fields.items():
attributes_definitions[field_name] = (
field_info.annotation,
Field(description=field_info.description),
)
entity_attributes_model = pydantic.create_model('EntityAttributes', **attributes_definitions)
summary_context: dict[str, Any] = {
'node': node_context,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
else [],
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_attributes(summary_context),
response_model=entity_attributes_model,
)
node.summary = llm_response.get('summary', node.summary)
node.name = llm_response.get('name', node.name)
node_attributes = {key: value for key, value in llm_response.items()}
with suppress(KeyError):
del node_attributes['summary']
del node_attributes['name']
node.attributes.update(node_attributes)
return node
async def dedupe_node_list(
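The per-node attribute pass builds its response model on the fly with pydantic.create_model: 'summary' and 'name' are always present, and any custom entity type contributes its own fields. A self-contained sketch of the pattern, with a hypothetical Person type standing in for a real custom entity type:

import pydantic
from pydantic import BaseModel, Field

class Person(BaseModel):
    """A human person mentioned in the conversation."""  # hypothetical custom entity type
    occupation: str | None = Field(None, description='Job or role of the person')

attributes_definitions: dict = {
    'summary': (str, Field(description='Summary of the entity, under 200 words')),
    'name': (str, Field(description='Name of the ENTITY')),
}
for field_name, field_info in Person.model_fields.items():
    attributes_definitions[field_name] = (
        field_info.annotation,
        Field(description=field_info.description),
    )

# The LLM response for this node is validated against the dynamically built model.
EntityAttributes = pydantic.create_model('EntityAttributes', **attributes_definitions)
print(list(EntityAttributes.model_fields))  # ['summary', 'name', 'occupation']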

View File

@ -72,12 +72,10 @@ async def get_edge_contradictions(
llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge]
) -> list[EntityEdge]:
start = time()
existing_edge_map = {edge.uuid: edge for edge in existing_edges}
new_edge_context = {'uuid': new_edge.uuid, 'name': new_edge.name, 'fact': new_edge.fact}
new_edge_context = {'fact': new_edge.fact}
existing_edge_context = [
{'uuid': existing_edge.uuid, 'name': existing_edge.name, 'fact': existing_edge.fact}
for existing_edge in existing_edges
{'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
]
context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context}
@ -86,14 +84,9 @@ async def get_edge_contradictions(
prompt_library.invalidate_edges.v2(context), response_model=InvalidatedEdges
)
contradicted_edge_data = llm_response.get('invalidated_edges', [])
contradicted_facts: list[int] = llm_response.get('contradicted_facts', [])
contradicted_edges: list[EntityEdge] = []
for edge_data in contradicted_edge_data:
if edge_data['uuid'] in existing_edge_map:
contradicted_edge = existing_edge_map[edge_data['uuid']]
contradicted_edge.fact = edge_data['fact']
contradicted_edges.append(contradicted_edge)
contradicted_edges: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts]
end = time()
logger.debug(

View File

@ -156,7 +156,7 @@ async def eval_graph(multi_session_count: int, session_length: int, llm_client=N
baseline_results[user_id],
add_episode_results[user_id],
add_episode_context[user_id],
strict=True,
strict=False,
):
context = {
'baseline': baseline_result,
@ -164,7 +164,6 @@ async def eval_graph(multi_session_count: int, session_length: int, llm_client=N
'message': episodes[0],
'previous_messages': episodes[1:],
}
print(context)
llm_response = await llm_client.generate_response(
prompt_library.eval.eval_add_episode_results(context),

View File

@ -103,16 +103,12 @@ async def test_resolve_extracted_edge_no_changes(
):
# Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
extract_dates_mock = AsyncMock(return_value=(None, None))
get_contradictions_mock = AsyncMock(return_value=[])
# Patch the function calls
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
get_contradictions_mock,
@ -123,62 +119,14 @@ async def test_resolve_extracted_edge_no_changes(
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
mock_previous_episodes,
)
assert resolved_edge.uuid == mock_extracted_edge.uuid
assert invalidated_edges == []
dedupe_mock.assert_called_once()
extract_dates_mock.assert_called_once()
get_contradictions_mock.assert_called_once()
@pytest.mark.asyncio
async def test_resolve_extracted_edge_with_dates(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
mock_previous_episodes,
monkeypatch: MonkeyPatch,
):
valid_at = datetime.now(timezone.utc) - timedelta(days=1)
invalid_at = datetime.now(timezone.utc) + timedelta(days=1)
# Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
extract_dates_mock = AsyncMock(return_value=(valid_at, invalid_at))
get_contradictions_mock = AsyncMock(return_value=[])
# Patch the function calls
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
get_contradictions_mock,
)
resolved_edge, invalidated_edges = await resolve_extracted_edge(
mock_llm_client,
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
mock_previous_episodes,
)
assert resolved_edge.valid_at == valid_at
assert resolved_edge.invalid_at == invalid_at
assert resolved_edge.expired_at is not None
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_with_invalidation(
mock_llm_client,
@ -206,16 +154,12 @@ async def test_resolve_extracted_edge_with_invalidation(
# Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
extract_dates_mock = AsyncMock(return_value=(None, None))
get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate])
# Patch the function calls
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions',
get_contradictions_mock,
@ -226,8 +170,6 @@ async def test_resolve_extracted_edge_with_invalidation(
mock_extracted_edge,
mock_related_edges,
mock_existing_edges,
mock_current_episode,
mock_previous_episodes,
)
assert resolved_edge.uuid == mock_extracted_edge.uuid

View File

@ -25,7 +25,6 @@ from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
get_edge_contradictions,
)
@ -265,67 +264,6 @@ async def test_invalidate_edges_partial_update():
assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
now = utc_now()
previous_episodes = [
EpisodicNode(
name='Previous Episode 1',
content='Bob: I work at XYZ company',
created_at=now - timedelta(days=2),
valid_at=now - timedelta(days=2),
source=EpisodeType.message,
source_description='Test previous episode for unit testing',
group_id='1',
),
EpisodicNode(
name='Previous Episode 2',
content="Alice: That's really cool!",
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source=EpisodeType.message,
source_description='Test previous episode for unit testing',
group_id='1',
),
]
episode = EpisodicNode(
name='Previous Episode',
content='Bob: It was cool, but I no longer work at company XYZ',
created_at=now,
valid_at=now,
source=EpisodeType.message,
source_description='Test previous episode for unit testing',
group_id='1',
)
return episode, previous_episodes
@pytest.mark.asyncio
@pytest.mark.integration
async def test_extract_edge_dates():
episode, previous_episodes = create_data_for_temporal_extraction()
# Create a new edge that partially updates an existing one
new_edge = EntityEdge(
uuid='e9',
source_node_uuid='2',
target_node_uuid='4',
name='LEFT_JOB',
fact='Bob no longer works at Company XYZ',
group_id='1',
created_at=utc_now(),
)
valid_at, invalid_at = await extract_edge_dates(
setup_llm_client(), new_edge, episode, previous_episodes
)
assert valid_at == episode.valid_at
assert invalid_at is None
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])