Mirror of https://github.com/getzep/graphiti.git (synced 2025-12-28 23:57:44 +00:00)
Custom ontology (#262)
* ontology
* extract and save node labels
* extract entity type properties
* neo4j upgrade needed
* add entity types
* update typing
* update types
* updates
* Update graphiti_core/utils/maintenance/node_operations.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* fix warning
* mypy updates
* update properties
* mypy ignore
* mypy types
* bump version

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
parent 6eccc9eecd
commit 29a071b2b8
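The commit above adds a custom-ontology path to Graphiti: callers can pass Pydantic models describing entity types to add_episode, extracted nodes are classified against those types, saved with the matching label, and their typed fields are stored as node attributes. A minimal usage sketch follows, adapted from the podcast example updated later in this diff; it assumes a local Neo4j instance, configured LLM credentials, and existing graph indices, and the connection values, group id, and episode content are placeholders rather than anything prescribed by the commit.

import asyncio
from datetime import datetime, timezone

from pydantic import BaseModel, Field

from graphiti_core import Graphiti


# Custom entity type: its fields become attributes on nodes classified as Person.
class Person(BaseModel):
    first_name: str | None = Field(..., description='First name')
    last_name: str | None = Field(..., description='Last name')
    occupation: str | None = Field(..., description="The person's work occupation")


async def main():
    # Placeholder connection details; assumes indices/constraints already exist.
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    await client.add_episode(
        name='Message 1',
        episode_body='Alice: I started a new job as a nurse last week.',
        reference_time=datetime.now(timezone.utc),
        source_description='Example conversation',
        group_id='example',
        # New in this commit: map type names to Pydantic models.
        entity_types={'Person': Person},
    )
    await client.close()


asyncio.run(main())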
@@ -1,142 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import csv
import logging
import os
import sys
from time import time

from dotenv import load_dotenv

from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
from graphiti_core import Graphiti
from graphiti_core.helpers import semaphore_gather
from graphiti_core.prompts import prompt_library
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF

load_dotenv()

neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'


def setup_logging():
    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Set the logging level to INFO

    # Create console handler and set level to INFO
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # Add formatter to console handler
    console_handler.setFormatter(formatter)

    # Add console handler to logger
    logger.addHandler(console_handler)

    return logger


async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
    search_start = time()
    results = await graphiti._search(
        query,
        COMBINED_HYBRID_SEARCH_RRF,
        group_ids=[str(group_id)],
    )
    search_end = time()
    search_duration = search_end - search_start

    facts = [edge.fact for edge in results.edges]
    entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
    context = {
        'facts': facts,
        'entity_summaries': entity_summaries,
        'query': 'Bob: ' + query,
    }

    llm_response = await graphiti.llm_client.generate_response(
        prompt_library.eval.qa_prompt(context)
    )
    response = llm_response.get('ANSWER', '')

    eval_context = {
        'query': 'Bob: ' + query,
        'answer': 'Alice: ' + answer,
        'response': 'Alice: ' + response,
    }

    eval_llm_response = await graphiti.llm_client.generate_response(
        prompt_library.eval.eval_prompt(eval_context)
    )
    eval_response = 1 if eval_llm_response.get('is_correct', False) else 0

    return {
        'Group id': group_id,
        'Question': query,
        'Answer': answer,
        'Response': response,
        'Score': eval_response,
        'Search Duration (ms)': search_duration * 1000,
    }


async def main():
    setup_logging()
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)

    fields = [
        'Group id',
        'Question',
        'Answer',
        'Response',
        'Score',
        'Search Duration (ms)',
    ]
    with open('../data/msc_eval.csv', 'w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=fields)
        writer.writeheader()

    qa = conversation_q_and_a()[0:500]
    i = 0
    while i < 500:
        qa_chunk = qa[i : i + 20]
        group_ids = range(len(qa))[i : i + 20]
        results = list(
            await semaphore_gather(
                *[
                    evaluate_qa(graphiti, str(group_id), query, answer)
                    for group_id, (query, answer) in zip(group_ids, qa_chunk)
                ]
            )
        )

        with open('../data/msc_eval.csv', 'a', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=fields)
            writer.writerows(results)
        i += 20

    await graphiti.close()


asyncio.run(main())

@@ -1,89 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
import os
import sys

from dotenv import load_dotenv

from examples.multi_session_conversation_memory.parse_msc_messages import (
    ParsedMscMessage,
    parse_msc_messages,
)
from graphiti_core import Graphiti
from graphiti_core.helpers import semaphore_gather

load_dotenv()

neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'


def setup_logging():
    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Set the logging level to INFO

    # Create console handler and set level to INFO
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # Add formatter to console handler
    console_handler.setFormatter(formatter)

    # Add console handler to logger
    logger.addHandler(console_handler)

    return logger


async def add_conversation(graphiti: Graphiti, group_id: str, messages: list[ParsedMscMessage]):
    for i, message in enumerate(messages):
        await graphiti.add_episode(
            name=f'Message {group_id + "-" + str(i)}',
            episode_body=f'{message.speaker_name}: {message.content}',
            reference_time=message.actual_timestamp,
            source_description='Multi-Session Conversation',
            group_id=group_id,
        )


async def main():
    setup_logging()
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
    msc_messages = parse_msc_messages()
    i = 0
    while i < len(msc_messages):
        msc_message_slice = msc_messages[i : i + 10]
        group_ids = range(len(msc_messages))[i : i + 10]

        await semaphore_gather(
            *[
                add_conversation(graphiti, str(group_id), messages)
                for group_id, messages in zip(group_ids, msc_message_slice)
            ]
        )

        i += 10


asyncio.run(main())

@@ -1,85 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
from datetime import datetime, timezone

from pydantic import BaseModel


class ParsedMscMessage(BaseModel):
    speaker_name: str
    actual_timestamp: datetime
    content: str
    group_id: str


def parse_msc_messages() -> list[list[ParsedMscMessage]]:
    msc_messages: list[list[ParsedMscMessage]] = []
    speakers = ['Alice', 'Bob']

    with open('../data/msc.jsonl') as file:
        data = [json.loads(line) for line in file]
        for i, conversation in enumerate(data):
            messages: list[ParsedMscMessage] = []
            for previous_dialog in conversation['previous_dialogs']:
                dialog = previous_dialog['dialog']
                speaker_idx = 0

                for utterance in dialog:
                    content = utterance['text']
                    messages.append(
                        ParsedMscMessage(
                            speaker_name=speakers[speaker_idx],
                            content=content,
                            actual_timestamp=datetime.now(timezone.utc),
                            group_id=str(i),
                        )
                    )
                    speaker_idx += 1
                    speaker_idx %= 2

            dialog = conversation['dialog']
            speaker_idx = 0
            for utterance in dialog:
                content = utterance['text']
                messages.append(
                    ParsedMscMessage(
                        speaker_name=speakers[speaker_idx],
                        content=content,
                        actual_timestamp=datetime.now(timezone.utc),
                        group_id=str(i),
                    )
                )
                speaker_idx += 1
                speaker_idx %= 2

            msc_messages.append(messages)

    return msc_messages


def conversation_q_and_a() -> list[tuple[str, str]]:
    with open('../data/msc.jsonl') as file:
        data = [json.loads(line) for line in file]

    qa: list[tuple[str, str]] = []
    for conversation in data:
        query = conversation['self_instruct']['B']
        answer = conversation['self_instruct']['A']

        qa.append((query, answer))
    return qa

@@ -20,6 +20,7 @@ import os
import sys

from dotenv import load_dotenv
+from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages

from graphiti_core import Graphiti
@@ -53,6 +54,12 @@ def setup_logging():
    return logger


+class Person(BaseModel):
+    first_name: str | None = Field(..., description='First name')
+    last_name: str | None = Field(..., description='Last name')
+    occupation: str | None = Field(..., description="The person's work occupation")
+
+
async def main():
    setup_logging()
    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
@@ -67,6 +74,7 @@ async def main():
            reference_time=message.actual_timestamp,
            source_description='Podcast Transcript',
            group_id='podcast',
+            entity_types={'Person': Person},
        )

@@ -262,6 +262,7 @@ class Graphiti:
        group_id: str = '',
        uuid: str | None = None,
        update_communities: bool = False,
+        entity_types: dict[str, BaseModel] | None = None,
    ) -> AddEpisodeResults:
        """
        Process an episode and update the graph.
@@ -336,7 +337,9 @@ class Graphiti:

            # Extract entities as nodes

-            extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
+            extracted_nodes = await extract_nodes(
+                self.llm_client, episode, previous_episodes, entity_types
+            )
            logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

            # Calculate Embeddings
@@ -362,6 +365,7 @@ class Graphiti:
                    existing_nodes_lists,
                    episode,
                    previous_episodes,
+                    entity_types,
                ),
                extract_edges(
                    self.llm_client, episode, extracted_nodes, previous_episodes, group_id

@@ -31,14 +31,16 @@ EPISODIC_NODE_SAVE_BULK = """

ENTITY_NODE_SAVE = """
    MERGE (n:Entity {uuid: $uuid})
-    SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+    SET n:$($labels)
+    SET n = $entity_data
    WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
    RETURN n.uuid AS uuid"""

ENTITY_NODE_SAVE_BULK = """
    UNWIND $nodes AS node
    MERGE (n:Entity {uuid: node.uuid})
-    SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
+    SET n:$(node.labels)
+    SET n = node
    WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
    RETURN n.uuid AS uuid
"""

@@ -255,6 +255,9 @@ class EpisodicNode(Node):
class EntityNode(Node):
    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
+    attributes: dict[str, Any] = Field(
+        default={}, description='Additional attributes of the node. Dependent on node labels'
+    )

    async def generate_name_embedding(self, embedder: EmbedderClient):
        start = time()
@@ -266,14 +269,21 @@ class EntityNode(Node):
        return self.name_embedding

    async def save(self, driver: AsyncDriver):
+        entity_data: dict[str, Any] = {
+            'uuid': self.uuid,
+            'name': self.name,
+            'name_embedding': self.name_embedding,
+            'group_id': self.group_id,
+            'summary': self.summary,
+            'created_at': self.created_at,
+        }
+
+        entity_data.update(self.attributes or {})
+
        result = await driver.execute_query(
            ENTITY_NODE_SAVE,
            uuid=self.uuid,
-            name=self.name,
-            group_id=self.group_id,
-            summary=self.summary,
            name_embedding=self.name_embedding,
-            created_at=self.created_at,
+            labels=self.labels + ['Entity'],
+            entity_data=entity_data,
            database_=DEFAULT_DATABASE,
        )
@@ -292,7 +302,9 @@ class EntityNode(Node):
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
        """,
        uuid=uuid,
        database_=DEFAULT_DATABASE,
@@ -317,7 +329,9 @@ class EntityNode(Node):
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
        """,
        uuids=uuids,
        database_=DEFAULT_DATABASE,
@@ -351,7 +365,9 @@ class EntityNode(Node):
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
            ORDER BY n.uuid DESC
        """
        + limit_query,
@@ -503,9 +519,10 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
        name=record['name'],
        group_id=record['group_id'],
        name_embedding=record['name_embedding'],
-        labels=['Entity'],
+        labels=record['labels'],
        created_at=record['created_at'].to_native(),
        summary=record['summary'],
+        attributes=record['attributes'],
    )

@@ -30,11 +30,19 @@ class MissedEntities(BaseModel):
    missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")


+class EntityClassification(BaseModel):
+    entity_classification: str = Field(
+        ...,
+        description='Dictionary of entity classifications. Key is the entity name and value is the entity type',
+    )
+
+
class Prompt(Protocol):
    extract_message: PromptVersion
    extract_json: PromptVersion
    extract_text: PromptVersion
    reflexion: PromptVersion
+    classify_nodes: PromptVersion


class Versions(TypedDict):
@@ -42,6 +50,7 @@ class Versions(TypedDict):
    extract_json: PromptFunction
    extract_text: PromptFunction
    reflexion: PromptFunction
+    classify_nodes: PromptFunction


def extract_message(context: dict[str, Any]) -> list[Message]:
@@ -66,6 +75,7 @@ Guidelines:
4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
5. Be as explicit as possible in your node names, using full names.
6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
+7. Extract preferences as their own nodes
"""
    return [
        Message(role='system', content=sys_prompt),
@@ -109,7 +119,7 @@ def extract_text(context: dict[str, Any]) -> list[Message]:

{context['custom_prompt']}

-Given the following text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:
+Given the above text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:

Guidelines:
1. Extract significant entities, concepts, or actors mentioned in the conversation.
@@ -147,9 +157,41 @@ extracted.
    ]


+def classify_nodes(context: dict[str, Any]) -> list[Message]:
+    sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted"""
+
+    user_prompt = f"""
+    <PREVIOUS MESSAGES>
+    {json.dumps([ep for ep in context['previous_episodes']], indent=2)}
+    </PREVIOUS MESSAGES>
+    <CURRENT MESSAGE>
+    {context["episode_content"]}
+    </CURRENT MESSAGE>
+
+    <EXTRACTED ENTITIES>
+    {context['extracted_entities']}
+    </EXTRACTED ENTITIES>
+
+    <ENTITY TYPES>
+    {context['entity_types']}
+    </ENTITY TYPES>
+
+    Given the above conversation, extracted entities, and provided entity types, classify the extracted entities.
+
+    Guidelines:
+    1. Each entity must have exactly one type
+    2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
+    """
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
+
+
versions: Versions = {
    'extract_message': extract_message,
    'extract_json': extract_json,
    'extract_text': extract_text,
    'reflexion': reflexion,
+    'classify_nodes': classify_nodes,
}

@@ -24,7 +24,8 @@ from .models import Message, PromptFunction, PromptVersion

class Summary(BaseModel):
    summary: str = Field(
-        ..., description='Summary containing the important information from both summaries'
+        ...,
+        description='Summary containing the important information about the entity. Under 500 words',
    )

@@ -68,7 +69,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
    return [
        Message(
            role='system',
-            content='You are a helpful assistant that combines summaries with new conversation context.',
+            content='You are a helpful assistant that extracts entity properties from the provided text.',
        ),
        Message(
            role='user',
@@ -79,18 +80,23 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
{json.dumps(context['episode_content'], indent=2)}
</MESSAGES>

-Given the above MESSAGES and the following ENTITY name and ENTITY CONTEXT, create a summary for the ENTITY. Your summary must only use
-information from the provided MESSAGES and from the ENTITY CONTEXT. Your summary should also only contain information relevant to the
-provided ENTITY.
+Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
+information from the provided MESSAGES. Your summary should also only contain information relevant to the
+provided ENTITY. Summaries must be under 500 words.

-Summaries must be under 500 words.
+In addition, extract any values for the provided entity properties based on their descriptions.

<ENTITY>
{context['node_name']}
</ENTITY>

<ENTITY CONTEXT>
{context['node_summary']}
</ENTITY CONTEXT>

+<ATTRIBUTES>
+{json.dumps(context['attributes'], indent=2)}
+</ATTRIBUTES>
""",
        ),
    ]

@@ -97,7 +97,9 @@ async def get_mentioned_nodes(
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
        """,
        uuids=episode_uuids,
        database_=DEFAULT_DATABASE,
@@ -223,8 +225,8 @@ async def edge_similarity_search(

    query: LiteralString = (
        """
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        """
+        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+        """
        + group_filter_query
        + filter_query
        + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -341,7 +343,9 @@ async def node_fulltext_search(
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
            ORDER BY score DESC
            LIMIT $limit
        """,
@@ -390,7 +394,9 @@ async def node_similarity_search(
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
            ORDER BY score DESC
            LIMIT $limit
        """,
@@ -427,7 +433,9 @@ async def node_bfs_search(
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
-            n.summary AS summary
+            n.summary AS summary,
+            labels(n) AS labels,
+            properties(n) AS attributes
            LIMIT $limit
        """,
        bfs_origin_node_uuids=bfs_origin_node_uuids,

@@ -23,6 +23,7 @@ from math import ceil
from neo4j import AsyncDriver, AsyncManagedTransaction
from numpy import dot, sqrt
from pydantic import BaseModel
from typing_extensions import Any

from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import semaphore_gather
@@ -109,8 +110,23 @@ async def add_nodes_and_edges_bulk_tx(
    episodes = [dict(episode) for episode in episodic_nodes]
    for episode in episodes:
        episode['source'] = str(episode['source'].value)
+    nodes: list[dict[str, Any]] = []
+    for node in entity_nodes:
+        entity_data: dict[str, Any] = {
+            'uuid': node.uuid,
+            'name': node.name,
+            'name_embedding': node.name_embedding,
+            'group_id': node.group_id,
+            'summary': node.summary,
+            'created_at': node.created_at,
+        }
+
+        entity_data.update(node.attributes or {})
+        entity_data['labels'] = list(set(node.labels + ['Entity']))
+        nodes.append(entity_data)
+
    await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
-    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes])
+    await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
    await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
    await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])

@@ -14,15 +14,19 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

+import ast
import logging
from time import time

+import pydantic
+from pydantic import BaseModel
+
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
-from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
+from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
from graphiti_core.utils.datetime_utils import utc_now
@@ -114,6 +118,7 @@ async def extract_nodes(
    llm_client: LLMClient,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
+    entity_types: dict[str, BaseModel] | None = None,
) -> list[EntityNode]:
    start = time()
    extracted_node_names: list[str] = []
@@ -144,15 +149,35 @@ async def extract_nodes(
        for entity in missing_entities:
            custom_prompt += f'\n{entity},'

+    node_classification_context = {
+        'episode_content': episode.content,
+        'previous_episodes': [ep.content for ep in previous_episodes],
+        'extracted_entities': extracted_node_names,
+        'entity_types': entity_types.keys() if entity_types is not None else [],
+    }
+
+    node_classifications: dict[str, str | None] = {}
+
+    if entity_types is not None:
+        llm_response = await llm_client.generate_response(
+            prompt_library.extract_nodes.classify_nodes(node_classification_context),
+            response_model=EntityClassification,
+        )
+        response_string = llm_response.get('entity_classification', '{}')
+        node_classifications.update(ast.literal_eval(response_string))
+
    end = time()
    logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
    # Convert the extracted data into EntityNode objects
    new_nodes = []
    for name in extracted_node_names:
+        entity_type = node_classifications.get(name)
+        labels = ['Entity'] if entity_type is None else ['Entity', entity_type]
+
        new_node = EntityNode(
            name=name,
            group_id=episode.group_id,
-            labels=['Entity'],
+            labels=labels,
            summary='',
            created_at=utc_now(),
        )
@@ -218,6 +243,7 @@ async def resolve_extracted_nodes(
    existing_nodes_lists: list[list[EntityNode]],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
) -> tuple[list[EntityNode], dict[str, str]]:
    uuid_map: dict[str, str] = {}
    resolved_nodes: list[EntityNode] = []
@@ -225,7 +251,12 @@ async def resolve_extracted_nodes(
        await semaphore_gather(
            *[
                resolve_extracted_node(
-                    llm_client, extracted_node, existing_nodes, episode, previous_episodes
+                    llm_client,
+                    extracted_node,
+                    existing_nodes,
+                    episode,
+                    previous_episodes,
+                    entity_types,
                )
                for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists)
            ]
@@ -245,6 +276,7 @@ async def resolve_extracted_node(
    existing_nodes: list[EntityNode],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityNode, dict[str, str]]:
    start = time()

@@ -273,19 +305,39 @@ async def resolve_extracted_node(
        'previous_episodes': [ep.content for ep in previous_episodes]
        if previous_episodes is not None
        else [],
+        'attributes': [],
    }

-    llm_response, node_summary_response = await semaphore_gather(
+    entity_type_classes: tuple[BaseModel, ...] = tuple()
+    if entity_types is not None:  # type: ignore
+        entity_type_classes = entity_type_classes + tuple(
+            filter(
+                lambda x: x is not None,  # type: ignore
+                [entity_types.get(entity_type) for entity_type in extracted_node.labels],  # type: ignore
+            )
+        )
+
+    for entity_type in entity_type_classes:
+        for field_name in entity_type.model_fields:
+            summary_context.get('attributes', []).append(field_name)  # type: ignore
+
+    entity_attributes_model = pydantic.create_model(  # type: ignore
+        'EntityAttributes',
+        __base__=entity_type_classes + (Summary,),  # type: ignore
+    )
+
+    llm_response, node_attributes_response = await semaphore_gather(
        llm_client.generate_response(
            prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
        ),
        llm_client.generate_response(
            prompt_library.summarize_nodes.summarize_context(summary_context),
-            response_model=Summary,
+            response_model=entity_attributes_model,
        ),
    )

-    extracted_node.summary = node_summary_response.get('summary', '')
+    extracted_node.summary = node_attributes_response.get('summary', '')
+    extracted_node.attributes.update(node_attributes_response)

    is_duplicate: bool = llm_response.get('is_duplicate', False)
    uuid: str | None = llm_response.get('uuid', None)

@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
-version = "0.6.1"
+version = "0.7.0"
description = "A temporal graph building library"
authors = [
    "Paul Paliychuk <paul@getzep.com>",