From 40e74a2e97f9d950f287d5842aa1b0f584fdfa47 Mon Sep 17 00:00:00 2001 From: Pavlo Paliychuk Date: Mon, 19 Aug 2024 09:37:56 -0400 Subject: [PATCH] fix: Address graph disconnect (#7) * fix: Address graph disconnect * chore: Remove valid_to and valid_from setting in extract edges step (will be handled during invalidation step) --- core/graphiti.py | 14 +++++++-- core/prompts/extract_edges.py | 13 ++++---- core/prompts/extract_nodes.py | 37 ++++++++++++++++++++++- core/utils/maintenance/edge_operations.py | 29 +++++++++--------- core/utils/maintenance/node_operations.py | 2 +- runner.py | 2 +- 6 files changed, 72 insertions(+), 25 deletions(-) diff --git a/core/graphiti.py b/core/graphiti.py index 83b40d1f..1a07227f 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -113,12 +113,16 @@ class Graphiti: await asyncio.gather( *[node.generate_name_embedding(embedder) for node in extracted_nodes] ) - existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver) - + logger.info( + f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}" + ) new_nodes = await dedupe_extracted_nodes( self.llm_client, extracted_nodes, existing_nodes ) + logger.info( + f"Deduped touched nodes: {[(n.name, n.uuid) for n in new_nodes]}" + ) nodes.extend(new_nodes) extracted_edges = await extract_edges( @@ -130,11 +134,17 @@ class Graphiti: ) existing_edges = await get_relevant_edges(extracted_edges, self.driver) + logger.info(f"Existing edges: {[(e.name, e.uuid) for e in existing_edges]}") + logger.info( + f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}" + ) new_edges = await dedupe_extracted_edges( self.llm_client, extracted_edges, existing_edges ) + logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}") + entity_edges.extend(new_edges) episodic_edges.extend( build_episodic_edges( diff --git a/core/prompts/extract_edges.py b/core/prompts/extract_edges.py index cc12206a..b60cc807 100644 --- a/core/prompts/extract_edges.py +++ b/core/prompts/extract_edges.py @@ -135,17 +135,18 @@ def v2(context: dict[str, any]) -> list[Message]: Message( role="user", content=f""" - Given the following context, extract new edges (relationships) that need to be added to the knowledge graph: + Given the following context, extract edges (relationships) that need to be added to the knowledge graph: Nodes: {json.dumps(context['nodes'], indent=2)} - New Episode: - Content: {context['episode_content']} + - Previous Episodes: + Episodes: {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)} + {context['episode_content']} <-- New Episode + - Extract new entity edges based on the content of the current episode, the given nodes, and context from previous episodes. + Extract entity edges based on the content of the current episode, the given nodes, and context from previous episodes. Guidelines: 1. Create edges only between the provided nodes. @@ -168,7 +169,7 @@ def v2(context: dict[str, any]) -> list[Message]: ] }} - If no new edges need to be added, return an empty list for "new_edges". + If no edges need to be added, return an empty list for "edges". """, ), ] diff --git a/core/prompts/extract_nodes.py b/core/prompts/extract_nodes.py index 8e3b1f55..3b54ad9e 100644 --- a/core/prompts/extract_nodes.py +++ b/core/prompts/extract_nodes.py @@ -7,11 +7,13 @@ from .models import Message, PromptVersion, PromptFunction class Prompt(Protocol): v1: PromptVersion v2: PromptVersion + v3: PromptVersion class Versions(TypedDict): v1: PromptFunction v2: PromptFunction + v3: PromptFunction def v1(context: dict[str, any]) -> list[Message]: @@ -103,4 +105,37 @@ def v2(context: dict[str, any]) -> list[Message]: ] -versions: Versions = {"v1": v1, "v2": v2} +def v3(context: dict[str, any]) -> list[Message]: + sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation.""" + + user_prompt = f""" +Given the following conversation, extract entity nodes that are explicitly or implicitly mentioned: + +Conversation: +{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)} +{context["episode_content"]} + +Guidelines: +1. ALWAYS extract the speaker/actor as the first node. The speaker is the part before the colon in each line of dialogue. +2. Extract other significant entities, concepts, or actors mentioned in the conversation. +3. Provide concise but informative summaries for each extracted node. +4. Avoid creating nodes for relationships or actions. + +Respond with a JSON object in the following format: +{{ + "new_nodes": [ + {{ + "name": "Unique identifier for the node (use the speaker's name for speaker nodes)", + "labels": ["Entity", "Speaker" for speaker nodes, "OptionalAdditionalLabel"], + "summary": "Brief summary of the node's role or significance" + }} + ] +}} +""" + return [ + Message(role="system", content=sys_prompt), + Message(role="user", content=user_prompt), + ] + + +versions: Versions = {"v1": v1, "v2": v2, "v3": v3} diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py index 47e004cc..b78fdc3a 100644 --- a/core/utils/maintenance/edge_operations.py +++ b/core/utils/maintenance/edge_operations.py @@ -170,20 +170,21 @@ async def extract_edges( # Convert the extracted data into EntityEdge objects edges = [] for edge_data in edges_data: - edge = EntityEdge( - source_node_uuid=edge_data["source_node_uuid"], - target_node_uuid=edge_data["target_node_uuid"], - name=edge_data["relation_type"], - fact=edge_data["fact"], - episodes=[episode.uuid], - created_at=datetime.now(), - valid_at=edge_data["valid_at"], - invalid_at=edge_data["invalid_at"], - ) - edges.append(edge) - logger.info( - f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})" - ) + if edge_data["target_node_uuid"] and edge_data["source_node_uuid"]: + edge = EntityEdge( + source_node_uuid=edge_data["source_node_uuid"], + target_node_uuid=edge_data["target_node_uuid"], + name=edge_data["relation_type"], + fact=edge_data["fact"], + episodes=[episode.uuid], + created_at=datetime.now(), + valid_at=None, + invalid_at=None, + ) + edges.append(edge) + logger.info( + f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})" + ) return edges diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py index 571c3316..1b67bf56 100644 --- a/core/utils/maintenance/node_operations.py +++ b/core/utils/maintenance/node_operations.py @@ -84,7 +84,7 @@ async def extract_nodes( } llm_response = await llm_client.generate_response( - prompt_library.extract_nodes.v2(context) + prompt_library.extract_nodes.v3(context) ) new_nodes_data = llm_response.get("new_nodes", []) logger.info(f"Extracted new nodes: {new_nodes_data}") diff --git a/runner.py b/runner.py index 967d16ea..cbfd9371 100644 --- a/runner.py +++ b/runner.py @@ -49,7 +49,7 @@ async def main(): ) await client.add_episode( name="Message 2", - episode_body="Paul: I love bananas", + episode_body="Paul: I own many bananas", source_description="WhatsApp Message", ) await client.add_episode(