diff --git a/examples/podcast/podcast_transcript.txt b/examples/podcast/podcast_transcript.txt index c73b6ed..2d78bb1 100644 --- a/examples/podcast/podcast_transcript.txt +++ b/examples/podcast/podcast_transcript.txt @@ -20,7 +20,7 @@ Fordham is a well-regarded private university in New York City, founded in 1841 There's a very daunting hall of portraits outside of my office. You know, all of these priests going back to 1841, 0 (1m 41s): -Tet, LO's own father was in fact a priest. But while getting his psychology PhD at Fordham, he met his Wouldbe wife, another graduate student, so he left the priesthood. Tanya was born in New York not long before the family moved to New Orleans, so Fordham is in her genes. +Tet, LO's own father was in fact a priest. But while getting his psychology PhD at Fordham, he met his Wouldbe wife, another graduate student, so he left the priesthood. Tania was born in New York not long before the family moved to New Orleans, so Fordham is in her genes. 1 (2m 0s): A good way to recruit me is they can tell me you exist because of us. 
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 3dcb691..eb9924e 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -72,7 +72,11 @@ from graphiti_core.utils.maintenance.graph_data_operations import ( build_indices_and_constraints, retrieve_episodes, ) -from graphiti_core.utils.maintenance.node_operations import extract_nodes, resolve_extracted_nodes +from graphiti_core.utils.maintenance.node_operations import ( + extract_attributes_from_nodes, + extract_nodes, + resolve_extracted_nodes, +) from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types @@ -370,15 +374,16 @@ class Graphiti: extract_edges(self.clients, episode, extracted_nodes, previous_episodes, group_id), ) - extracted_edges_with_resolved_pointers = resolve_edge_pointers( - extracted_edges, uuid_map - ) + edges = resolve_edge_pointers(extracted_edges, uuid_map) - resolved_edges, invalidated_edges = await resolve_extracted_edges( - self.clients, - extracted_edges_with_resolved_pointers, - episode, - previous_episodes, + (resolved_edges, invalidated_edges), hydrated_nodes = await semaphore_gather( + resolve_extracted_edges( + self.clients, + edges, + ), + extract_attributes_from_nodes( + self.clients, nodes, episode, previous_episodes, entity_types + ), ) entity_edges = resolved_edges + invalidated_edges diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 4381020..aac3228 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -29,7 +29,7 @@ load_dotenv() DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None) USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False)) SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20)) -MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 1)) +MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0)) DEFAULT_PAGE_LIMIT = 20 
RUNTIME_QUERY: LiteralString = ( diff --git a/graphiti_core/prompts/dedupe_edges.py b/graphiti_core/prompts/dedupe_edges.py index 4ba4f37..5354f3c 100644 --- a/graphiti_core/prompts/dedupe_edges.py +++ b/graphiti_core/prompts/dedupe_edges.py @@ -23,10 +23,9 @@ from .models import Message, PromptFunction, PromptVersion class EdgeDuplicate(BaseModel): - is_duplicate: bool = Field(..., description='true or false') - uuid: str | None = Field( - None, - description="uuid of the existing edge like '5d643020624c42fa9de13f97b1b3fa39' or null", + duplicate_fact_id: int = Field( + ..., + description='id of the duplicate fact. If no duplicate facts are found, default to -1.', ) @@ -69,9 +68,8 @@ def edge(context: dict[str, Any]) -> list[Message]: Task: - 1. If the New Edges represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the - response. Otherwise, return 'is_duplicate: false' - 2. If is_duplicate is true, also return the uuid of the existing edge in the response + If the NEW EDGE represents the same factual information as any edge in EXISTING EDGES, return the id of the duplicate fact. + If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return -1. Guidelines: 1. The facts do not need to be completely identical to be duplicates, they just need to express the same information. 
diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py index 0d6e0fd..5a870e0 100644 --- a/graphiti_core/prompts/dedupe_nodes.py +++ b/graphiti_core/prompts/dedupe_nodes.py @@ -23,14 +23,9 @@ from .models import Message, PromptFunction, PromptVersion class NodeDuplicate(BaseModel): - is_duplicate: bool = Field(..., description='true or false') - uuid: str | None = Field( - None, - description="uuid of the existing node like '5d643020624c42fa9de13f97b1b3fa39' or null", - ) - name: str = Field( + duplicate_node_id: int = Field( ..., - description="Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)", + description='id of the duplicate node. If no duplicate nodes are found, default to -1.', ) @@ -64,28 +59,20 @@ def node(context: dict[str, Any]) -> list[Message]: {json.dumps(context['existing_nodes'], indent=2)} - Given the above EXISTING NODES and their attributes, MESSAGE, and PREVIOUS MESSAGES. Determine if the NEW NODE extracted from the conversation + Given the above EXISTING NODES and their attributes, MESSAGE, and PREVIOUS MESSAGES; Determine if the NEW NODE extracted from the conversation is a duplicate entity of one of the EXISTING NODES. - {json.dumps(context['extracted_nodes'], indent=2)} + {json.dumps(context['extracted_node'], indent=2)} Task: - 1. If the New Node represents the same entity as any node in Existing Nodes, return 'is_duplicate: true' in the - response. Otherwise, return 'is_duplicate: false' - 2. If is_duplicate is true, also return the uuid of the existing node in the response - 3. If is_duplicate is true, return a name for the node that is the most complete full name. + If the NEW NODE is a duplicate of any node in EXISTING NODES, set duplicate_node_id to the + id of the EXISTING NODE that is the duplicate. If the NEW NODE is not a duplicate of any of the EXISTING NODES, + duplicate_node_id should be set to -1. 
Guidelines: - 1. Use both the name and summary of nodes to determine if the entities are duplicates, + 1. Use the name, summary, and attributes of nodes to determine if the entities are duplicates, duplicate nodes may have different names - - Respond with a JSON object in the following format: - {{ - "is_duplicate": true or false, - "uuid": "uuid of the existing node like 5d643020624c42fa9de13f97b1b3fa39 or null", - "name": "Updated name of the new node (use the best name between the new node's name, an existing duplicate name, or a combination of both)" - }} """, ), ] diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py index 9bef859..e7f41cd 100644 --- a/graphiti_core/prompts/extract_edges.py +++ b/graphiti_core/prompts/extract_edges.py @@ -23,10 +23,18 @@ from .models import Message, PromptFunction, PromptVersion class Edge(BaseModel): - relation_type: str = Field(..., description='RELATION_TYPE_IN_CAPS') - source_entity_name: str = Field(..., description='name of the source entity') - target_entity_name: str = Field(..., description='name of the target entity') - fact: str = Field(..., description='extracted factual information') + relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE') + source_entity_name: str = Field(..., description='The name of the source entity of the fact.') + target_entity_name: str = Field(..., description='The name of the target entity of the fact.') + fact: str = Field(..., description='') + valid_at: str | None = Field( + None, + description='The date and time when the relationship described by the edge fact became true or was established. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)', + ) + invalid_at: str | None = Field( + None, + description='The date and time when the relationship described by the edge fact stopped being true or ended. 
Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)', + ) class ExtractedEdges(BaseModel): @@ -51,32 +59,59 @@ def edge(context: dict[str, Any]) -> list[Message]: return [ Message( role='system', - content='You are an expert fact extractor that extracts fact triples from text.', + content='You are an expert fact extractor that extracts fact triples from text. ' + '1. Extracted fact triples should also be extracted with relevant date information. ' + '2. Treat the CURRENT TIME as the time the CURRENT MESSAGE was sent. All temporal information should be extracted relative to this time.', ), Message( role='user', content=f""" - - {json.dumps([ep for ep in context['previous_episodes']], indent=2)} - - - {context['episode_content']} - - - - {context['nodes']} - - - {context['custom_prompt']} + +{json.dumps([ep for ep in context['previous_episodes']], indent=2)} + - Given the above MESSAGES and ENTITIES, extract all facts pertaining to the listed ENTITIES from the CURRENT MESSAGE. - - Guidelines: - 1. Extract facts only between the provided entities. - 2. Each fact should represent a clear relationship between two DISTINCT nodes. - 3. The relation_type should be a concise, all-caps description of the fact (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR). - 4. Provide a more detailed fact containing all relevant information. - 5. Consider temporal aspects of relationships when relevant. + +{context['episode_content']} + + + +{context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases + + + +{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions + + +# TASK +Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE. +Only extract facts that: +- involve two DISTINCT ENTITIES from the ENTITIES list, +- are clearly stated or unambiguously implied in the CURRENT MESSAGE, +- and can be represented as edges in a knowledge graph. 
+ +You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity. + + +{context['custom_prompt']} + +# EXTRACTION RULES + +1. Only emit facts where both the subject and object match IDs in ENTITIES. +2. Each fact must involve two **distinct** entities. +3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT). +4. Do not emit duplicate or semantically redundant facts. +5. The `fact_text` should quote or closely paraphrase the original source sentence(s). +6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week"). +7. Do **not** hallucinate or infer temporal bounds from unrelated events. + +# DATETIME RULES + +- Use ISO 8601 with “Z” suffix (UTC) (e.g., 2025-04-30T00:00:00Z). +- If the fact is ongoing (present tense), set `valid_at` to REFERENCE_TIME. +- If a change/termination is expressed, set `invalid_at` to the relevant timestamp. +- Leave both fields `null` if no explicit or resolvable time is stated. +- If only a date is mentioned (no time), assume 00:00:00. +- If only a year is mentioned, use January 1st at 00:00:00. """, ), ] diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index a63badc..a4824ec 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -22,8 +22,16 @@ from pydantic import BaseModel, Field from .models import Message, PromptFunction, PromptVersion -class ExtractedNodes(BaseModel): - extracted_node_names: list[str] = Field(..., description='Name of the extracted entity') +class ExtractedEntity(BaseModel): + name: str = Field(..., description='Name of the extracted entity') + entity_type_id: int = Field( + description='ID of the classified entity type. 
' + 'Must be one of the provided entity_type_id integers.', + ) + + +class ExtractedEntities(BaseModel): + extracted_entities: list[ExtractedEntity] = Field(..., description='List of extracted entities') class MissedEntities(BaseModel): @@ -50,6 +58,7 @@ class Prompt(Protocol): extract_text: PromptVersion reflexion: PromptVersion classify_nodes: PromptVersion + extract_attributes: PromptVersion class Versions(TypedDict): @@ -58,31 +67,49 @@ class Versions(TypedDict): extract_text: PromptFunction reflexion: PromptFunction classify_nodes: PromptFunction + extract_attributes: PromptFunction def extract_message(context: dict[str, Any]) -> list[Message]: - sys_prompt = """You are an AI assistant that extracts entity nodes from conversational messages. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation.""" + sys_prompt = """You are an AI assistant that extracts entity nodes from conversational messages. + Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.""" user_prompt = f""" {json.dumps([ep for ep in context['previous_episodes']], indent=2)} + {context['episode_content']} + +{context['entity_types']} + + +Instructions: + +You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE. + +1. **Speaker Extraction**: Always extract the speaker (the part before the colon `:` in each dialogue line) as the first entity node. + - If the speaker is mentioned again in the message, treat both mentions as a **single entity**. + +2. **Entity Identification**: + - Extract all significant entities, concepts, or actors that are **explicitly or implicitly** mentioned in the CURRENT MESSAGE. + - **Exclude** entities mentioned only in the PREVIOUS MESSAGES (they are for context only). + +3. 
**Entity Classification**: + - Use the descriptions in ENTITY TYPES to classify each extracted entity. + - Assign the appropriate `entity_type_id` for each one. + +4. **Exclusions**: + - Do NOT extract entities representing relationships or actions. + - Do NOT extract dates, times, or other temporal information—these will be handled separately. + +5. **Formatting**: + - Be **explicit and unambiguous** in naming entities (e.g., use full names when available). + {context['custom_prompt']} - -Given the above conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned: - -Guidelines: -1. ALWAYS extract the speaker/actor as the first node. The speaker is the part before the colon in each line of dialogue. -2. Extract other significant entities, concepts, or actors mentioned in the CURRENT MESSAGE. -3. DO NOT create nodes for relationships or actions. -4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later). -5. Be as explicit as possible in your node names, using full names. -6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context. -7. Extract preferences as their own nodes """ return [ Message(role='system', content=sys_prompt), @@ -92,7 +119,7 @@ Guidelines: def extract_json(context: dict[str, Any]) -> list[Message]: sys_prompt = """You are an AI assistant that extracts entity nodes from JSON. 
- Your primary task is to identify and extract relevant entities from JSON files""" + Your primary task is to extract and classify relevant entities from JSON files""" user_prompt = f""" : @@ -101,10 +128,15 @@ def extract_json(context: dict[str, Any]) -> list[Message]: {context['episode_content']} + +{context['entity_types']} + {context['custom_prompt']} -Given the above source description and JSON, extract relevant entity nodes from the provided JSON: +Given the above source description and JSON, extract relevant entities from the provided JSON. +For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions. +Indicate the classified entity type by providing its entity_type_id. Guidelines: 1. Always try to extract an entities that the JSON represents. This will often be something like a "name" or "user field @@ -117,17 +149,23 @@ Guidelines: def extract_text(context: dict[str, Any]) -> list[Message]: - sys_prompt = """You are an AI assistant that extracts entity nodes from text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the provided text.""" + sys_prompt = """You are an AI assistant that extracts entity nodes from text. + Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text.""" user_prompt = f""" {context['episode_content']} + +{context['entity_types']} + + +Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned. +For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions. +Indicate the classified entity type by providing its entity_type_id. {context['custom_prompt']} -Given the above text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned: - Guidelines: 1. Extract significant entities, concepts, or actors mentioned in the conversation. 2. 
Avoid creating nodes for relationships or actions. @@ -196,10 +234,43 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]: ] +def extract_attributes(context: dict[str, Any]) -> list[Message]: + return [ + Message( + role='system', + content='You are a helpful assistant that extracts entity properties from the provided text.', + ), + Message( + role='user', + content=f""" + + + {json.dumps(context['previous_episodes'], indent=2)} + {json.dumps(context['episode_content'], indent=2)} + + + Given the above MESSAGES and the following ENTITY, update any of its attributes based on the information provided + in MESSAGES. Use the provided attribute descriptions to better understand how each attribute should be determined. + + Guidelines: + 1. Do not hallucinate entity property values if they cannot be found in the current context. + 2. Only use the provided MESSAGES and ENTITY to set attribute values. + 3. The summary attribute represents a summary of the ENTITY, and should be updated with new information about the Entity from the MESSAGES. + Summaries must be no longer than 200 words. 
+ + + {context['node']} + + """, ), + ] + + versions: Versions = { 'extract_message': extract_message, 'extract_json': extract_json, 'extract_text': extract_text, 'reflexion': reflexion, 'classify_nodes': classify_nodes, + 'extract_attributes': extract_attributes, } diff --git a/graphiti_core/prompts/invalidate_edges.py b/graphiti_core/prompts/invalidate_edges.py index 1647396..f30048a 100644 --- a/graphiti_core/prompts/invalidate_edges.py +++ b/graphiti_core/prompts/invalidate_edges.py @@ -21,14 +21,10 @@ from pydantic import BaseModel, Field from .models import Message, PromptFunction, PromptVersion -class InvalidatedEdge(BaseModel): - uuid: str = Field(..., description='The UUID of the edge to be invalidated') - fact: str = Field(..., description='Updated fact of the edge') - - class InvalidatedEdges(BaseModel): - invalidated_edges: list[InvalidatedEdge] = Field( - ..., description='List of edges that should be invalidated' + contradicted_facts: list[int] = Field( + ..., + description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.', ) @@ -78,18 +74,22 @@ def v2(context: dict[str, Any]) -> list[Message]: return [ Message( role='system', - content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.', + content='You are an AI assistant that determines which facts contradict each other.', ), Message( role='user', content=f""" - Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to invalidations with the New Edge. + Based on the provided EXISTING FACTS and a NEW FACT, determine which existing facts the new fact contradicts. + Return a list containing all ids of the facts that are contradicted by the NEW FACT. + If there are no contradicted facts, return an empty list. 
- Existing Edges: + {context['existing_edges']} + - New Edge: + {context['new_edge']} + """, ), ] diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 56f87d9..ac7f8d1 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -341,10 +341,10 @@ async def node_fulltext_search( query = ( """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS n, score - WHERE n:Entity - """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) + YIELD node AS n, score + WHERE n:Entity + """ + filter_query + ENTITY_NODE_RETURN + """ @@ -672,21 +672,36 @@ async def get_relevant_nodes( """ + filter_query + """ - WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score - WHERE score > $min_score - WITH node, n, score - ORDER BY score DESC - RETURN node.uuid AS search_node_uuid, - collect({ - uuid: n.uuid, - name: n.name, - name_embedding: n.name_embedding, - group_id: n.group_id, - created_at: n.created_at, - summary: n.summary, - labels: labels(n), - attributes: properties(n) - })[..$limit] AS matches + WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score + WHERE score > $min_score + WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids + + CALL db.index.fulltext.queryNodes("node_name_and_summary", 'group_id:"' + $group_id + '" AND ' + node.name, {limit: $limit}) + YIELD node AS m + WHERE m.group_id = $group_id + WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes + + WITH node, + top_vector_nodes, + [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes + + WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes + + UNWIND combined_nodes AS combined_node + WITH node, collect(DISTINCT combined_node) AS deduped_nodes + + RETURN + node.uuid AS 
search_node_uuid, + [x IN deduped_nodes | { + uuid: x.uuid, + name: x.name, + name_embedding: x.name_embedding, + group_id: x.group_id, + created_at: x.created_at, + summary: x.summary, + labels: labels(x), + attributes: properties(x) + }] AS matches """ ) diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 1ce1e16..0b8206a 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -33,9 +33,8 @@ from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges -from graphiti_core.utils.datetime_utils import utc_now +from graphiti_core.utils.datetime_utils import ensure_utc, utc_now from graphiti_core.utils.maintenance.temporal_operations import ( - extract_edge_dates, get_edge_contradictions, ) @@ -100,12 +99,13 @@ async def extract_edges( 'episode_content': episode.content, 'nodes': [node.name for node in nodes], 'previous_episodes': [ep.content for ep in previous_episodes], + 'reference_time': episode.valid_at, 'custom_prompt': '', } facts_missed = True reflexion_iterations = 0 - while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS: + while facts_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS: llm_response = await llm_client.generate_response( prompt_library.extract_edges.edge(context), response_model=ExtractedEdges, @@ -118,7 +118,9 @@ async def extract_edges( reflexion_iterations += 1 if reflexion_iterations < MAX_REFLEXION_ITERATIONS: reflexion_response = await llm_client.generate_response( - prompt_library.extract_edges.reflexion(context), response_model=MissingFacts + prompt_library.extract_edges.reflexion(context), + response_model=MissingFacts, + 
max_tokens=extract_edges_max_tokens, ) missing_facts = reflexion_response.get('missing_facts', []) @@ -134,9 +136,33 @@ async def extract_edges( end = time() logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms') + if len(edges_data) == 0: + return [] + # Convert the extracted data into EntityEdge objects edges = [] for edge_data in edges_data: + # Validate Edge Date information + valid_at = edge_data.get('valid_at', None) + invalid_at = edge_data.get('invalid_at', None) + valid_at_datetime = None + invalid_at_datetime = None + + if valid_at: + try: + valid_at_datetime = ensure_utc( + datetime.fromisoformat(valid_at.replace('Z', '+00:00')) + ) + except ValueError as e: + logger.warning(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}') + + if invalid_at: + try: + invalid_at_datetime = ensure_utc( + datetime.fromisoformat(invalid_at.replace('Z', '+00:00')) + ) + except ValueError as e: + logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}') edge = EntityEdge( source_node_uuid=node_uuids_by_name_map.get( edge_data.get('source_entity_name', ''), '' @@ -149,8 +175,8 @@ async def extract_edges( fact=edge_data.get('fact', ''), episodes=[episode.uuid], created_at=utc_now(), - valid_at=None, - invalid_at=None, + valid_at=valid_at_datetime, + invalid_at=invalid_at_datetime, ) edges.append(edge) logger.debug( @@ -211,25 +237,22 @@ async def dedupe_extracted_edges( async def resolve_extracted_edges( clients: GraphitiClients, extracted_edges: list[EntityEdge], - current_episode: EpisodicNode, - previous_episodes: list[EpisodicNode], ) -> tuple[list[EntityEdge], list[EntityEdge]]: driver = clients.driver llm_client = clients.llm_client - related_edges_lists: list[list[EntityEdge]] = await get_relevant_edges( - driver, extracted_edges, SearchFilters() + search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather( + get_relevant_edges(driver, extracted_edges, SearchFilters()), 
+ get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()), ) + related_edges_lists, edge_invalidation_candidates = search_results + logger.debug( f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}' ) - edge_invalidation_candidates: list[list[EntityEdge]] = await get_edge_invalidation_candidates( - driver, extracted_edges, SearchFilters() - ) - - # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates + # resolve edges with related edges in the graph and find invalidation candidates results: list[tuple[EntityEdge, list[EntityEdge]]] = list( await semaphore_gather( *[ @@ -238,11 +261,9 @@ async def resolve_extracted_edges( extracted_edge, related_edges, existing_edges, - current_episode, - previous_episodes, ) for extracted_edge, related_edges, existing_edges in zip( - extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=False + extracted_edges, related_edges_lists, edge_invalidation_candidates, strict=True ) ] ) @@ -265,6 +286,9 @@ async def resolve_extracted_edges( def resolve_edge_contradictions( resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge] ) -> list[EntityEdge]: + if len(invalidation_candidates) == 0: + return [] + # Determine which contradictory edges need to be expired invalidated_edges: list[EntityEdge] = [] for edge in invalidation_candidates: @@ -297,21 +321,15 @@ async def resolve_extracted_edge( extracted_edge: EntityEdge, related_edges: list[EntityEdge], existing_edges: list[EntityEdge], - current_episode: EpisodicNode, - previous_episodes: list[EpisodicNode], ) -> tuple[EntityEdge, list[EntityEdge]]: - resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather( + resolved_edge, invalidation_candidates = await semaphore_gather( dedupe_extracted_edge(llm_client, extracted_edge, related_edges), - extract_edge_dates(llm_client, extracted_edge, current_episode, 
previous_episodes), get_edge_contradictions(llm_client, extracted_edge, existing_edges), ) now = utc_now() - resolved_edge.valid_at = valid_at if valid_at else resolved_edge.valid_at - resolved_edge.invalid_at = invalid_at if invalid_at else resolved_edge.invalid_at - - if invalid_at and not resolved_edge.expired_at: + if resolved_edge.invalid_at and not resolved_edge.expired_at: resolved_edge.expired_at = now # Determine if the new_edge needs to be expired @@ -339,16 +357,17 @@ async def resolve_extracted_edge( async def dedupe_extracted_edge( llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge] ) -> EntityEdge: + if len(related_edges) == 0: + return extracted_edge + start = time() # Prepare context for LLM related_edges_context = [ - {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges + {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges) ] extracted_edge_context = { - 'uuid': extracted_edge.uuid, - 'name': extracted_edge.name, 'fact': extracted_edge.fact, } @@ -361,15 +380,13 @@ async def dedupe_extracted_edge( prompt_library.dedupe_edges.edge(context), response_model=EdgeDuplicate ) - is_duplicate: bool = llm_response.get('is_duplicate', False) - uuid: str | None = llm_response.get('uuid', None) + duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1) - edge = extracted_edge - if is_duplicate: - for existing_edge in related_edges: - if existing_edge.uuid != uuid: - continue - edge = existing_edge + edge = ( + related_edges[duplicate_fact_id] + if 0 <= duplicate_fact_id < len(related_edges) + else extracted_edge + ) end = time() logger.debug( diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 255aae2..9efe771 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -20,7 +20,7 @@ from time import time from typing import Any 
import pydantic -from pydantic import BaseModel +from pydantic import BaseModel, Field from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather @@ -28,8 +28,11 @@ from graphiti_core.llm_client import LLMClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate -from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities -from graphiti_core.prompts.summarize_nodes import Summary +from graphiti_core.prompts.extract_nodes import ( + ExtractedEntities, + ExtractedEntity, + MissedEntities, +) from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import get_relevant_nodes from graphiti_core.utils.datetime_utils import utc_now @@ -37,66 +40,6 @@ from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) -async def extract_message_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], - custom_prompt='', -) -> list[str]: - # Prepare context for LLM - context = { - 'episode_content': episode.content, - 'episode_timestamp': episode.valid_at.isoformat(), - 'previous_episodes': [ep.content for ep in previous_episodes], - 'custom_prompt': custom_prompt, - } - - llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_message(context), response_model=ExtractedNodes - ) - extracted_node_names = llm_response.get('extracted_node_names', []) - return extracted_node_names - - -async def extract_text_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], - custom_prompt='', -) -> list[str]: - # Prepare context for LLM - context = { - 'episode_content': episode.content, - 'episode_timestamp': 
episode.valid_at.isoformat(), - 'previous_episodes': [ep.content for ep in previous_episodes], - 'custom_prompt': custom_prompt, - } - - llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_text(context), ExtractedNodes - ) - extracted_node_names = llm_response.get('extracted_node_names', []) - return extracted_node_names - - -async def extract_json_nodes( - llm_client: LLMClient, episode: EpisodicNode, custom_prompt='' -) -> list[str]: - # Prepare context for LLM - context = { - 'episode_content': episode.content, - 'episode_timestamp': episode.valid_at.isoformat(), - 'source_description': episode.source_description, - 'custom_prompt': custom_prompt, - } - - llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.extract_json(context), ExtractedNodes - ) - extracted_node_names = llm_response.get('extracted_node_names', []) - return extracted_node_names - - async def extract_nodes_reflexion( llm_client: LLMClient, episode: EpisodicNode, @@ -127,82 +70,88 @@ async def extract_nodes( start = time() llm_client = clients.llm_client embedder = clients.embedder - extracted_node_names: list[str] = [] + llm_response = {} custom_prompt = '' entities_missed = True reflexion_iterations = 0 - while entities_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS: + + entity_types_context = [ + { + 'entity_type_id': 0, + 'entity_type_name': 'Entity', + 'entity_type_description': 'Default entity classification. 
Use this entity type if the entity is not one of the other listed types.', + } + ] + + entity_types_context += ( + [ + { + 'entity_type_id': i + 1, + 'entity_type_name': type_name, + 'entity_type_description': type_model.__doc__, + } + for i, (type_name, type_model) in enumerate(entity_types.items()) + ] + if entity_types is not None + else [] + ) + + context = { + 'episode_content': episode.content, + 'episode_timestamp': episode.valid_at.isoformat(), + 'previous_episodes': [ep.content for ep in previous_episodes], + 'custom_prompt': custom_prompt, + 'entity_types': entity_types_context, + } + + while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS: if episode.source == EpisodeType.message: - extracted_node_names = await extract_message_nodes( - llm_client, episode, previous_episodes, custom_prompt + llm_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_message(context), + response_model=ExtractedEntities, ) elif episode.source == EpisodeType.text: - extracted_node_names = await extract_text_nodes( - llm_client, episode, previous_episodes, custom_prompt + llm_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_text(context), response_model=ExtractedEntities ) elif episode.source == EpisodeType.json: - extracted_node_names = await extract_json_nodes(llm_client, episode, custom_prompt) + llm_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_json(context), response_model=ExtractedEntities + ) + extracted_entities: list[ExtractedEntity] = [ + ExtractedEntity(**entity_types_context) + for entity_types_context in llm_response.get('extracted_entities', []) + ] + + reflexion_iterations += 1 if reflexion_iterations < MAX_REFLEXION_ITERATIONS: missing_entities = await extract_nodes_reflexion( - llm_client, episode, previous_episodes, extracted_node_names + llm_client, + episode, + previous_episodes, + [entity.name for entity in extracted_entities], 
) entities_missed = len(missing_entities) != 0 - custom_prompt = 'The following entities were missed in a previous extraction: ' + custom_prompt = 'Make sure that the following entities are extracted: ' for entity in missing_entities: custom_prompt += f'\n{entity},' - reflexion_iterations += 1 - - node_classification_context = { - 'episode_content': episode.content, - 'previous_episodes': [ep.content for ep in previous_episodes], - 'extracted_entities': extracted_node_names, - 'entity_types': { - type_name: values.model_json_schema().get('description') - for type_name, values in entity_types.items() - } - if entity_types is not None - else {}, - } - - node_classifications: dict[str, str | None] = {} - - if entity_types is not None: - try: - llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.classify_nodes(node_classification_context), - response_model=EntityClassification, - ) - entity_classifications = llm_response.get('entity_classifications', []) - node_classifications.update( - { - entity_classification.get('name'): entity_classification.get('entity_type') - for entity_classification in entity_classifications - } - ) - # catch classification errors and continue if we can't classify - except Exception as e: - logger.exception(e) end = time() - logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms') + logger.debug(f'Extracted new nodes: {extracted_entities} in {(end - start) * 1000} ms') # Convert the extracted data into EntityNode objects extracted_nodes = [] - for name in extracted_node_names: - entity_type = node_classifications.get(name) - if entity_types is not None and entity_type not in entity_types: - entity_type = None - - labels = ( - ['Entity'] - if entity_type is None or entity_type == 'None' or entity_type == 'null' - else ['Entity', entity_type] + for extracted_entity in extracted_entities: + entity_type_name = entity_types_context[extracted_entity.entity_type_id].get( + 
'entity_type_name' ) + labels: list[str] = list({'Entity', str(entity_type_name)}) + new_node = EntityNode( - name=name, + name=extracted_entity.name, group_id=episode.group_id, labels=labels, summary='', @@ -282,29 +231,29 @@ async def resolve_extracted_nodes( driver, extracted_nodes, SearchFilters() ) - uuid_map: dict[str, str] = {} - resolved_nodes: list[EntityNode] = [] - results: list[tuple[EntityNode, dict[str, str]]] = list( - await semaphore_gather( - *[ - resolve_extracted_node( - llm_client, - extracted_node, - existing_nodes, - episode, - previous_episodes, - entity_types, + resolved_nodes: list[EntityNode] = await semaphore_gather( + *[ + resolve_extracted_node( + llm_client, + extracted_node, + existing_nodes, + episode, + previous_episodes, + entity_types.get( + next((item for item in extracted_node.labels if item != 'Entity'), '') ) - for extracted_node, existing_nodes in zip( - extracted_nodes, existing_nodes_lists, strict=False - ) - ] - ) + if entity_types is not None + else None, + ) + for extracted_node, existing_nodes in zip( + extracted_nodes, existing_nodes_lists, strict=True + ) + ] ) - for result in results: - uuid_map.update(result[1]) - resolved_nodes.append(result[0]) + uuid_map: dict[str, str] = {} + for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True): + uuid_map[extracted_node.uuid] = resolved_node.uuid logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}') @@ -317,124 +266,151 @@ async def resolve_extracted_node( existing_nodes: list[EntityNode], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, BaseModel] | None = None, -) -> tuple[EntityNode, dict[str, str]]: + entity_type: BaseModel | None = None, +) -> EntityNode: start = time() + if len(existing_nodes) == 0: + return extracted_node # Prepare context for LLM existing_nodes_context = [ - {**{'uuid': node.uuid, 'name': node.name, 'summary': node.summary}, 
**node.attributes} - for node in existing_nodes + { + **{ + 'id': i, + 'name': node.name, + 'entity_types': node.labels, + 'summary': node.summary, + }, + **node.attributes, + } + for i, node in enumerate(existing_nodes) ] extracted_node_context = { - 'uuid': extracted_node.uuid, 'name': extracted_node.name, - 'summary': extracted_node.summary, + 'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore + 'entity_type_description': entity_type.__doc__ + if entity_type is not None + else 'Default Entity Type', } context = { 'existing_nodes': existing_nodes_context, - 'extracted_nodes': extracted_node_context, + 'extracted_node': extracted_node_context, 'episode_content': episode.content if episode is not None else '', 'previous_episodes': [ep.content for ep in previous_episodes] if previous_episodes is not None else [], } - summary_context: dict[str, Any] = { - 'node_name': extracted_node.name, - 'node_summary': extracted_node.summary, - 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': [ep.content for ep in previous_episodes] - if previous_episodes is not None - else [], - } - - attributes: list[dict[str, str]] = [] - - entity_type_classes: tuple[BaseModel, ...] 
= tuple() - if entity_types is not None: # type: ignore - entity_type_classes = entity_type_classes + tuple( - filter( - lambda x: x is not None, # type: ignore - [entity_types.get(entity_type) for entity_type in extracted_node.labels], # type: ignore - ) - ) - - for entity_type in entity_type_classes: - for field_name, field_info in entity_type.model_fields.items(): - attributes.append( - { - 'attribute_name': field_name, - 'attribute_description': field_info.description or '', - } - ) - - summary_context['attributes'] = attributes - - entity_attributes_model = pydantic.create_model( # type: ignore - 'EntityAttributes', - __base__=entity_type_classes + (Summary,), # type: ignore + llm_response = await llm_client.generate_response( + prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate ) - llm_response, node_attributes_response = await semaphore_gather( - llm_client.generate_response( - prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate - ), - llm_client.generate_response( - prompt_library.summarize_nodes.summarize_context(summary_context), - response_model=entity_attributes_model, - ), + duplicate_id: int = llm_response.get('duplicate_node_id', -1) + + node = ( + existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node ) - extracted_node.summary = node_attributes_response.get('summary', '') - node_attributes = { - key: value if (value != 'None' or key == 'summary') else None - for key, value in node_attributes_response.items() - } - - with suppress(KeyError): - del node_attributes['summary'] - - extracted_node.attributes.update(node_attributes) - - is_duplicate: bool = llm_response.get('is_duplicate', False) - uuid: str | None = llm_response.get('uuid', None) - name = llm_response.get('name', '') - - node = extracted_node - uuid_map: dict[str, str] = {} - if is_duplicate: - for existing_node in existing_nodes: - if existing_node.uuid != uuid: - continue - summary_response = await 
llm_client.generate_response( - prompt_library.summarize_nodes.summarize_pair( - {'node_summaries': [extracted_node.summary, existing_node.summary]} - ), - response_model=Summary, - ) - node = existing_node - node.name = name - node.summary = summary_response.get('summary', '') - - new_attributes = extracted_node.attributes - existing_attributes = existing_node.attributes - for attribute_name, attribute_value in existing_attributes.items(): - if new_attributes.get(attribute_name) is None: - new_attributes[attribute_name] = attribute_value - node.attributes = new_attributes - node.labels = list(set(existing_node.labels + extracted_node.labels)) - - uuid_map[extracted_node.uuid] = existing_node.uuid - end = time() logger.debug( f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms' ) - return node, uuid_map + return node + + +async def extract_attributes_from_nodes( + clients: GraphitiClients, + nodes: list[EntityNode], + episode: EpisodicNode | None = None, + previous_episodes: list[EpisodicNode] | None = None, + entity_types: dict[str, BaseModel] | None = None, +) -> list[EntityNode]: + llm_client = clients.llm_client + embedder = clients.embedder + + updated_nodes: list[EntityNode] = await semaphore_gather( + *[ + extract_attributes_from_node( + llm_client, + node, + episode, + previous_episodes, + entity_types.get(next((item for item in node.labels if item != 'Entity'), '')) + if entity_types is not None + else None, + ) + for node in nodes + ] + ) + + await create_entity_node_embeddings(embedder, updated_nodes) + + return updated_nodes + + +async def extract_attributes_from_node( + llm_client: LLMClient, + node: EntityNode, + episode: EpisodicNode | None = None, + previous_episodes: list[EpisodicNode] | None = None, + entity_type: BaseModel | None = None, +) -> EntityNode: + node_context: dict[str, Any] = { + 'name': node.name, + 'summary': node.summary, + 'entity_types': node.labels, + 'attributes': node.attributes, + } + + 
attributes_definitions: dict[str, Any] = { + 'summary': ( + str, + Field( + description='Summary containing the important information about the entity. Under 200 words', + ), + ), + 'name': ( + str, + Field(description='Name of the ENTITY'), + ), + } + + if entity_type is not None: + for field_name, field_info in entity_type.model_fields.items(): + attributes_definitions[field_name] = ( + field_info.annotation, + Field(description=field_info.description), + ) + + entity_attributes_model = pydantic.create_model('EntityAttributes', **attributes_definitions) + + summary_context: dict[str, Any] = { + 'node': node_context, + 'episode_content': episode.content if episode is not None else '', + 'previous_episodes': [ep.content for ep in previous_episodes] + if previous_episodes is not None + else [], + } + + llm_response = await llm_client.generate_response( + prompt_library.extract_nodes.extract_attributes(summary_context), + response_model=entity_attributes_model, + ) + + node.summary = llm_response.get('summary', node.summary) + node.name = llm_response.get('name', node.name) + node_attributes = {key: value for key, value in llm_response.items()} + + with suppress(KeyError): + del node_attributes['summary'] + del node_attributes['name'] + + node.attributes.update(node_attributes) + + return node async def dedupe_node_list( diff --git a/graphiti_core/utils/maintenance/temporal_operations.py b/graphiti_core/utils/maintenance/temporal_operations.py index 98dcb4b..c63b8e2 100644 --- a/graphiti_core/utils/maintenance/temporal_operations.py +++ b/graphiti_core/utils/maintenance/temporal_operations.py @@ -72,12 +72,10 @@ async def get_edge_contradictions( llm_client: LLMClient, new_edge: EntityEdge, existing_edges: list[EntityEdge] ) -> list[EntityEdge]: start = time() - existing_edge_map = {edge.uuid: edge for edge in existing_edges} - new_edge_context = {'uuid': new_edge.uuid, 'name': new_edge.name, 'fact': new_edge.fact} + new_edge_context = {'fact': new_edge.fact} 
existing_edge_context = [ - {'uuid': existing_edge.uuid, 'name': existing_edge.name, 'fact': existing_edge.fact} - for existing_edge in existing_edges + {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges) ] context = {'new_edge': new_edge_context, 'existing_edges': existing_edge_context} @@ -86,14 +84,9 @@ async def get_edge_contradictions( prompt_library.invalidate_edges.v2(context), response_model=InvalidatedEdges ) - contradicted_edge_data = llm_response.get('invalidated_edges', []) + contradicted_facts: list[int] = llm_response.get('contradicted_facts', []) - contradicted_edges: list[EntityEdge] = [] - for edge_data in contradicted_edge_data: - if edge_data['uuid'] in existing_edge_map: - contradicted_edge = existing_edge_map[edge_data['uuid']] - contradicted_edge.fact = edge_data['fact'] - contradicted_edges.append(contradicted_edge) + contradicted_edges: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts] end = time() logger.debug( diff --git a/tests/evals/eval_e2e_graph_building.py b/tests/evals/eval_e2e_graph_building.py index fce284b..0a637e7 100644 --- a/tests/evals/eval_e2e_graph_building.py +++ b/tests/evals/eval_e2e_graph_building.py @@ -156,7 +156,7 @@ async def eval_graph(multi_session_count: int, session_length: int, llm_client=N baseline_results[user_id], add_episode_results[user_id], add_episode_context[user_id], - strict=True, + strict=False, ): context = { 'baseline': baseline_result, @@ -164,7 +164,6 @@ async def eval_graph(multi_session_count: int, session_length: int, llm_client=N 'message': episodes[0], 'previous_messages': episodes[1:], } - print(context) llm_response = await llm_client.generate_response( prompt_library.eval.eval_add_episode_results(context), diff --git a/tests/utils/maintenance/test_edge_operations.py b/tests/utils/maintenance/test_edge_operations.py index 97873df..bcd3ddd 100644 --- a/tests/utils/maintenance/test_edge_operations.py +++ 
b/tests/utils/maintenance/test_edge_operations.py @@ -103,16 +103,12 @@ async def test_resolve_extracted_edge_no_changes( ): # Mock the function calls dedupe_mock = AsyncMock(return_value=mock_extracted_edge) - extract_dates_mock = AsyncMock(return_value=(None, None)) get_contradictions_mock = AsyncMock(return_value=[]) # Patch the function calls monkeypatch.setattr( 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock ) - monkeypatch.setattr( - 'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock - ) monkeypatch.setattr( 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions', get_contradictions_mock, @@ -123,62 +119,14 @@ async def test_resolve_extracted_edge_no_changes( mock_extracted_edge, mock_related_edges, mock_existing_edges, - mock_current_episode, - mock_previous_episodes, ) assert resolved_edge.uuid == mock_extracted_edge.uuid assert invalidated_edges == [] dedupe_mock.assert_called_once() - extract_dates_mock.assert_called_once() get_contradictions_mock.assert_called_once() -@pytest.mark.asyncio -async def test_resolve_extracted_edge_with_dates( - mock_llm_client, - mock_extracted_edge, - mock_related_edges, - mock_existing_edges, - mock_current_episode, - mock_previous_episodes, - monkeypatch: MonkeyPatch, -): - valid_at = datetime.now(timezone.utc) - timedelta(days=1) - invalid_at = datetime.now(timezone.utc) + timedelta(days=1) - - # Mock the function calls - dedupe_mock = AsyncMock(return_value=mock_extracted_edge) - extract_dates_mock = AsyncMock(return_value=(valid_at, invalid_at)) - get_contradictions_mock = AsyncMock(return_value=[]) - - # Patch the function calls - monkeypatch.setattr( - 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock - ) - monkeypatch.setattr( - 'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock - ) - monkeypatch.setattr( - 
'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions', - get_contradictions_mock, - ) - - resolved_edge, invalidated_edges = await resolve_extracted_edge( - mock_llm_client, - mock_extracted_edge, - mock_related_edges, - mock_existing_edges, - mock_current_episode, - mock_previous_episodes, - ) - - assert resolved_edge.valid_at == valid_at - assert resolved_edge.invalid_at == invalid_at - assert resolved_edge.expired_at is not None - assert invalidated_edges == [] - - @pytest.mark.asyncio async def test_resolve_extracted_edge_with_invalidation( mock_llm_client, @@ -206,16 +154,12 @@ async def test_resolve_extracted_edge_with_invalidation( # Mock the function calls dedupe_mock = AsyncMock(return_value=mock_extracted_edge) - extract_dates_mock = AsyncMock(return_value=(None, None)) get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate]) # Patch the function calls monkeypatch.setattr( 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock ) - monkeypatch.setattr( - 'graphiti_core.utils.maintenance.edge_operations.extract_edge_dates', extract_dates_mock - ) monkeypatch.setattr( 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions', get_contradictions_mock, @@ -226,8 +170,6 @@ async def test_resolve_extracted_edge_with_invalidation( mock_extracted_edge, mock_related_edges, mock_existing_edges, - mock_current_episode, - mock_previous_episodes, ) assert resolved_edge.uuid == mock_extracted_edge.uuid diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py index 7bb30ba..6b3b53d 100644 --- a/tests/utils/maintenance/test_temporal_operations_int.py +++ b/tests/utils/maintenance/test_temporal_operations_int.py @@ -25,7 +25,6 @@ from graphiti_core.llm_client import LLMConfig, OpenAIClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.utils.datetime_utils import utc_now from 
graphiti_core.utils.maintenance.temporal_operations import ( - extract_edge_dates, get_edge_contradictions, ) @@ -265,67 +264,6 @@ async def test_invalidate_edges_partial_update(): assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated -def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]: - now = utc_now() - - previous_episodes = [ - EpisodicNode( - name='Previous Episode 1', - content='Bob: I work at XYZ company', - created_at=now - timedelta(days=2), - valid_at=now - timedelta(days=2), - source=EpisodeType.message, - source_description='Test previous episode for unit testing', - group_id='1', - ), - EpisodicNode( - name='Previous Episode 2', - content="Alice: That's really cool!", - created_at=now - timedelta(days=1), - valid_at=now - timedelta(days=1), - source=EpisodeType.message, - source_description='Test previous episode for unit testing', - group_id='1', - ), - ] - - episode = EpisodicNode( - name='Previous Episode', - content='Bob: It was cool, but I no longer work at company XYZ', - created_at=now, - valid_at=now, - source=EpisodeType.message, - source_description='Test previous episode for unit testing', - group_id='1', - ) - - return episode, previous_episodes - - -@pytest.mark.asyncio -@pytest.mark.integration -async def test_extract_edge_dates(): - episode, previous_episodes = create_data_for_temporal_extraction() - - # Create a new edge that partially updates an existing one - new_edge = EntityEdge( - uuid='e9', - source_node_uuid='2', - target_node_uuid='4', - name='LEFT_JOB', - fact='Bob no longer works at Company XYZ', - group_id='1', - created_at=utc_now(), - ) - - valid_at, invalid_at = await extract_edge_dates( - setup_llm_client(), new_edge, episode, previous_episodes - ) - - assert valid_at == episode.valid_at - assert invalid_at is None - - # Run the tests if __name__ == '__main__': pytest.main([__file__])