Custom ontology (#262)

* ontology

* extract and save node labels

* extract entity type properties

* neo4j upgrade needed

* add entity types

* update typing

* update types

* updates

* Update graphiti_core/utils/maintenance/node_operations.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* fix warning

* mypy updates

* update properties

* mypy ignore

* mypy types

* bump version

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Preston Rasmussen 2025-02-13 12:17:52 -05:00 committed by GitHub
parent 6eccc9eecd
commit 29a071b2b8
13 changed files with 189 additions and 350 deletions
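
At a glance, this change adds an optional entity_types argument to Graphiti.add_episode so callers can supply a custom ontology as Pydantic models. Below is a minimal usage sketch assembled from the diffs that follow; the connection details, the Person fields, and the episode content are illustrative and not part of the commit:

import asyncio
from datetime import datetime, timezone

from pydantic import BaseModel, Field

from graphiti_core import Graphiti


class Person(BaseModel):
    # Field descriptions guide the attribute-extraction prompt.
    first_name: str | None = Field(..., description='First name')
    last_name: str | None = Field(..., description='Last name')
    occupation: str | None = Field(..., description="The person's work occupation")


async def main():
    client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    await client.add_episode(
        name='Message 0',
        episode_body='Jane: I started a new job as a nurse last week.',
        reference_time=datetime.now(timezone.utc),
        source_description='Example conversation',
        group_id='example',
        entity_types={'Person': Person},  # new parameter introduced by this commit
    )
    await client.close()


asyncio.run(main())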

View File

@@ -1,142 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import csv
import logging
import os
import sys
from time import time
from dotenv import load_dotenv
from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
from graphiti_core import Graphiti
from graphiti_core.helpers import semaphore_gather
from graphiti_core.prompts import prompt_library
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
search_start = time()
results = await graphiti._search(
query,
COMBINED_HYBRID_SEARCH_RRF,
group_ids=[str(group_id)],
)
search_end = time()
search_duration = search_end - search_start
facts = [edge.fact for edge in results.edges]
entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
context = {
'facts': facts,
'entity_summaries': entity_summaries,
'query': 'Bob: ' + query,
}
llm_response = await graphiti.llm_client.generate_response(
prompt_library.eval.qa_prompt(context)
)
response = llm_response.get('ANSWER', '')
eval_context = {
'query': 'Bob: ' + query,
'answer': 'Alice: ' + answer,
'response': 'Alice: ' + response,
}
eval_llm_response = await graphiti.llm_client.generate_response(
prompt_library.eval.eval_prompt(eval_context)
)
eval_response = 1 if eval_llm_response.get('is_correct', False) else 0
return {
'Group id': group_id,
'Question': query,
'Answer': answer,
'Response': response,
'Score': eval_response,
'Search Duration (ms)': search_duration * 1000,
}
async def main():
setup_logging()
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
fields = [
'Group id',
'Question',
'Answer',
'Response',
'Score',
'Search Duration (ms)',
]
with open('../data/msc_eval.csv', 'w', newline='') as file:
writer = csv.DictWriter(file, fieldnames=fields)
writer.writeheader()
qa = conversation_q_and_a()[0:500]
i = 0
while i < 500:
qa_chunk = qa[i : i + 20]
group_ids = range(len(qa))[i : i + 20]
results = list(
await semaphore_gather(
*[
evaluate_qa(graphiti, str(group_id), query, answer)
for group_id, (query, answer) in zip(group_ids, qa_chunk)
]
)
)
with open('../data/msc_eval.csv', 'a', newline='') as file:
writer = csv.DictWriter(file, fieldnames=fields)
writer.writerows(results)
i += 20
await graphiti.close()
asyncio.run(main())

View File

@@ -1,89 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
import os
import sys
from dotenv import load_dotenv
from examples.multi_session_conversation_memory.parse_msc_messages import (
ParsedMscMessage,
parse_msc_messages,
)
from graphiti_core import Graphiti
from graphiti_core.helpers import semaphore_gather
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
async def add_conversation(graphiti: Graphiti, group_id: str, messages: list[ParsedMscMessage]):
for i, message in enumerate(messages):
await graphiti.add_episode(
name=f'Message {group_id + "-" + str(i)}',
episode_body=f'{message.speaker_name}: {message.content}',
reference_time=message.actual_timestamp,
source_description='Multi-Session Conversation',
group_id=group_id,
)
async def main():
setup_logging()
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
msc_messages = parse_msc_messages()
i = 0
while i < len(msc_messages):
msc_message_slice = msc_messages[i : i + 10]
group_ids = range(len(msc_messages))[i : i + 10]
await semaphore_gather(
*[
add_conversation(graphiti, str(group_id), messages)
for group_id, messages in zip(group_ids, msc_message_slice)
]
)
i += 10
asyncio.run(main())

View File

@@ -1,85 +0,0 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from datetime import datetime, timezone
from pydantic import BaseModel
class ParsedMscMessage(BaseModel):
speaker_name: str
actual_timestamp: datetime
content: str
group_id: str
def parse_msc_messages() -> list[list[ParsedMscMessage]]:
msc_messages: list[list[ParsedMscMessage]] = []
speakers = ['Alice', 'Bob']
with open('../data/msc.jsonl') as file:
data = [json.loads(line) for line in file]
for i, conversation in enumerate(data):
messages: list[ParsedMscMessage] = []
for previous_dialog in conversation['previous_dialogs']:
dialog = previous_dialog['dialog']
speaker_idx = 0
for utterance in dialog:
content = utterance['text']
messages.append(
ParsedMscMessage(
speaker_name=speakers[speaker_idx],
content=content,
actual_timestamp=datetime.now(timezone.utc),
group_id=str(i),
)
)
speaker_idx += 1
speaker_idx %= 2
dialog = conversation['dialog']
speaker_idx = 0
for utterance in dialog:
content = utterance['text']
messages.append(
ParsedMscMessage(
speaker_name=speakers[speaker_idx],
content=content,
actual_timestamp=datetime.now(timezone.utc),
group_id=str(i),
)
)
speaker_idx += 1
speaker_idx %= 2
msc_messages.append(messages)
return msc_messages
def conversation_q_and_a() -> list[tuple[str, str]]:
with open('../data/msc.jsonl') as file:
data = [json.loads(line) for line in file]
qa: list[tuple[str, str]] = []
for conversation in data:
query = conversation['self_instruct']['B']
answer = conversation['self_instruct']['A']
qa.append((query, answer))
return qa

View File

@@ -20,6 +20,7 @@ import os
import sys
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti
@@ -53,6 +54,12 @@ def setup_logging():
return logger
class Person(BaseModel):
first_name: str | None = Field(..., description='First name')
last_name: str | None = Field(..., description='Last name')
occupation: str | None = Field(..., description="The person's work occupation")
async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
@@ -67,6 +74,7 @@ async def main():
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='podcast',
entity_types={'Person': Person},
)

View File

@@ -262,6 +262,7 @@ class Graphiti:
group_id: str = '',
uuid: str | None = None,
update_communities: bool = False,
entity_types: dict[str, BaseModel] | None = None,
) -> AddEpisodeResults:
"""
Process an episode and update the graph.
@@ -336,7 +337,9 @@ class Graphiti:
# Extract entities as nodes
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes, entity_types
)
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
# Calculate Embeddings
@@ -362,6 +365,7 @@
existing_nodes_lists,
episode,
previous_episodes,
entity_types,
),
extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes, group_id

View File

@@ -31,14 +31,16 @@ EPISODIC_NODE_SAVE_BULK = """
ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n:$($labels)
SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""
ENTITY_NODE_SAVE_BULK = """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
SET n:$(node.labels)
SET n = node
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""

View File

@@ -255,6 +255,9 @@ class EpisodicNode(Node):
class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
attributes: dict[str, Any] = Field(
default={}, description='Additional attributes of the node. Dependent on node labels'
)
async def generate_name_embedding(self, embedder: EmbedderClient):
start = time()
@@ -266,14 +269,21 @@ class EntityNode(Node):
return self.name_embedding
async def save(self, driver: AsyncDriver):
entity_data: dict[str, Any] = {
'uuid': self.uuid,
'name': self.name,
'name_embedding': self.name_embedding,
'group_id': self.group_id,
'summary': self.summary,
'created_at': self.created_at,
}
entity_data.update(self.attributes or {})
result = await driver.execute_query(
ENTITY_NODE_SAVE,
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
labels=self.labels + ['Entity'],
entity_data=entity_data,
database_=DEFAULT_DATABASE,
)
@@ -292,7 +302,9 @@ class EntityNode(Node):
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
""",
uuid=uuid,
database_=DEFAULT_DATABASE,
@@ -317,7 +329,9 @@
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
""",
uuids=uuids,
database_=DEFAULT_DATABASE,
@@ -351,7 +365,9 @@
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
ORDER BY n.uuid DESC
"""
+ limit_query,
@@ -503,9 +519,10 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
name=record['name'],
group_id=record['group_id'],
name_embedding=record['name_embedding'],
labels=['Entity'],
labels=record['labels'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
attributes=record['attributes'],
)
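
In short, custom attributes now live on EntityNode.attributes and are merged into the property map on save, while labels carry the entity type. A small sketch of constructing and persisting such a node; the values and driver setup are placeholders, and in practice generate_name_embedding runs before save:

from datetime import datetime, timezone

from neo4j import AsyncGraphDatabase

from graphiti_core.nodes import EntityNode


async def save_person_node() -> None:
    node = EntityNode(
        name='Jane Doe',
        group_id='example',
        labels=['Entity', 'Person'],
        summary='Jane Doe recently started work as a nurse.',
        attributes={'first_name': 'Jane', 'last_name': 'Doe', 'occupation': 'Nurse'},
        created_at=datetime.now(timezone.utc),
    )
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    # name_embedding would normally be populated via node.generate_name_embedding(embedder) first.
    await node.save(driver)
    await driver.close()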

View File

@@ -30,11 +30,19 @@ class MissedEntities(BaseModel):
missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
class EntityClassification(BaseModel):
entity_classification: str = Field(
...,
description='Dictionary of entity classifications. Key is the entity name and value is the entity type',
)
class Prompt(Protocol):
extract_message: PromptVersion
extract_json: PromptVersion
extract_text: PromptVersion
reflexion: PromptVersion
classify_nodes: PromptVersion
class Versions(TypedDict):
@@ -42,6 +50,7 @@
extract_json: PromptFunction
extract_text: PromptFunction
reflexion: PromptFunction
classify_nodes: PromptFunction
def extract_message(context: dict[str, Any]) -> list[Message]:
@@ -66,6 +75,7 @@ Guidelines:
4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
5. Be as explicit as possible in your node names, using full names.
6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
7. Extract preferences as their own nodes
"""
return [
Message(role='system', content=sys_prompt),
@@ -109,7 +119,7 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
{context['custom_prompt']}
Given the following text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:
Given the above text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:
Guidelines:
1. Extract significant entities, concepts, or actors mentioned in the conversation.
@@ -147,9 +157,41 @@ extracted.
]
def classify_nodes(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted"""
user_prompt = f"""
<PREVIOUS MESSAGES>
{json.dumps([ep for ep in context['previous_episodes']], indent=2)}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context["episode_content"]}
</CURRENT MESSAGE>
<EXTRACTED ENTITIES>
{context['extracted_entities']}
</EXTRACTED ENTITIES>
<ENTITY TYPES>
{context['entity_types']}
</ENTITY TYPES>
Given the above conversation, extracted entities, and provided entity types, classify the extracted entities.
Guidelines:
1. Each entity must have exactly one type
2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
"""
return [
Message(role='system', content=sys_prompt),
Message(role='user', content=user_prompt),
]
versions: Versions = {
'extract_message': extract_message,
'extract_json': extract_json,
'extract_text': extract_text,
'reflexion': reflexion,
'classify_nodes': classify_nodes,
}
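
Note that EntityClassification carries the mapping as a string, so the caller (node_operations.py, further down) parses it with ast.literal_eval. A hypothetical response just to show the expected shape; the entity names and types are made up:

import ast

llm_response = {'entity_classification': "{'Jane Doe': 'Person', 'Acme Corp': None}"}

node_classifications: dict[str, str | None] = {}
node_classifications.update(ast.literal_eval(llm_response.get('entity_classification', '{}')))
# node_classifications is now {'Jane Doe': 'Person', 'Acme Corp': None}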

View File

@@ -24,7 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
class Summary(BaseModel):
summary: str = Field(
..., description='Summary containing the important information from both summaries'
...,
description='Summary containing the important information about the entity. Under 500 words',
)
@@ -68,7 +69,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that combines summaries with new conversation context.',
content='You are a helpful assistant that extracts entity properties from the provided text.',
),
Message(
role='user',
@@ -79,18 +80,23 @@
{json.dumps(context['episode_content'], indent=2)}
</MESSAGES>
Given the above MESSAGES and the following ENTITY name and ENTITY CONTEXT, create a summary for the ENTITY. Your summary must only use
information from the provided MESSAGES and from the ENTITY CONTEXT. Your summary should also only contain information relevant to the
provided ENTITY.
Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
information from the provided MESSAGES. Your summary should also only contain information relevant to the
provided ENTITY. Summaries must be under 500 words.
Summaries must be under 500 words.
In addition, extract any values for the provided entity properties based on their descriptions.
<ENTITY>
{context['node_name']}
</ENTITY>
<ENTITY CONTEXT>
{context['node_summary']}
</ENTITY CONTEXT>
<ATTRIBUTES>
{json.dumps(context['attributes'], indent=2)}
</ATTRIBUTES>
""",
),
]

View File

@@ -97,7 +97,9 @@ async def get_mentioned_nodes(
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
""",
uuids=episode_uuids,
database_=DEFAULT_DATABASE,
@@ -223,8 +225,8 @@ async def edge_similarity_search(
query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -341,7 +343,9 @@ async def node_fulltext_search(
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
ORDER BY score DESC
LIMIT $limit
""",
@@ -390,7 +394,9 @@ async def node_similarity_search(
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
ORDER BY score DESC
LIMIT $limit
""",
@@ -427,7 +433,9 @@ async def node_bfs_search(
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
bfs_origin_node_uuids=bfs_origin_node_uuids,
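
Because the node queries now also return labels(n) and properties(n), search results surface the custom classification and attributes directly. A hypothetical snippet reusing the search recipe from the deleted evaluation example above; the query and group id are placeholders:

from graphiti_core import Graphiti
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF


async def show_node_types(graphiti: Graphiti) -> None:
    results = await graphiti._search('Who is Jane Doe?', COMBINED_HYBRID_SEARCH_RRF, group_ids=['example'])
    for node in results.nodes:
        # labels includes the custom type (e.g. 'Person'); attributes holds its extracted fields
        print(node.name, node.labels, node.attributes.get('occupation'))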

View File

@@ -23,6 +23,7 @@ from math import ceil
from neo4j import AsyncDriver, AsyncManagedTransaction
from numpy import dot, sqrt
from pydantic import BaseModel
from typing_extensions import Any
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import semaphore_gather
@@ -109,8 +110,23 @@ async def add_nodes_and_edges_bulk_tx(
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
nodes: list[dict[str, Any]] = []
for node in entity_nodes:
entity_data: dict[str, Any] = {
'uuid': node.uuid,
'name': node.name,
'name_embedding': node.name_embedding,
'group_id': node.group_id,
'summary': node.summary,
'created_at': node.created_at,
}
entity_data.update(node.attributes or {})
entity_data['labels'] = list(set(node.labels + ['Entity']))
nodes.append(entity_data)
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes])
await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
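
For reference, one entry of the nodes list assembled above might look like the following; the values are placeholders, the point being that custom attributes become top-level properties and labels always includes Entity:

from datetime import datetime, timezone

example_bulk_node = {
    'uuid': 'node-uuid',
    'name': 'Jane Doe',
    'name_embedding': [0.0] * 1024,
    'group_id': 'example',
    'summary': 'Jane Doe recently started work as a nurse.',
    'created_at': datetime(2025, 2, 13, tzinfo=timezone.utc),
    # merged in from node.attributes (hypothetical Person fields)
    'first_name': 'Jane',
    'last_name': 'Doe',
    'occupation': 'Nurse',
    'labels': ['Entity', 'Person'],
}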

View File

@@ -14,15 +14,19 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import ast
import logging
from time import time
import pydantic
from pydantic import BaseModel
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
from graphiti_core.prompts.summarize_nodes import Summary
from graphiti_core.utils.datetime_utils import utc_now
@@ -114,6 +118,7 @@ async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
entity_types: dict[str, BaseModel] | None = None,
) -> list[EntityNode]:
start = time()
extracted_node_names: list[str] = []
@@ -144,15 +149,35 @@
for entity in missing_entities:
custom_prompt += f'\n{entity},'
node_classification_context = {
'episode_content': episode.content,
'previous_episodes': [ep.content for ep in previous_episodes],
'extracted_entities': extracted_node_names,
'entity_types': entity_types.keys() if entity_types is not None else [],
}
node_classifications: dict[str, str | None] = {}
if entity_types is not None:
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.classify_nodes(node_classification_context),
response_model=EntityClassification,
)
response_string = llm_response.get('entity_classification', '{}')
node_classifications.update(ast.literal_eval(response_string))
end = time()
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
# Convert the extracted data into EntityNode objects
new_nodes = []
for name in extracted_node_names:
entity_type = node_classifications.get(name)
labels = ['Entity'] if entity_type is None else ['Entity', entity_type]
new_node = EntityNode(
name=name,
group_id=episode.group_id,
labels=['Entity'],
labels=labels,
summary='',
created_at=utc_now(),
)
@@ -218,6 +243,7 @@ async def resolve_extracted_nodes(
existing_nodes_lists: list[list[EntityNode]],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = []
@@ -225,7 +251,12 @@
await semaphore_gather(
*[
resolve_extracted_node(
llm_client, extracted_node, existing_nodes, episode, previous_episodes
llm_client,
extracted_node,
existing_nodes,
episode,
previous_episodes,
entity_types,
)
for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists)
]
@@ -245,6 +276,7 @@ async def resolve_extracted_node(
existing_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityNode, dict[str, str]]:
start = time()
@@ -273,19 +305,39 @@ async def resolve_extracted_node(
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
else [],
'attributes': [],
}
llm_response, node_summary_response = await semaphore_gather(
entity_type_classes: tuple[BaseModel, ...] = tuple()
if entity_types is not None: # type: ignore
entity_type_classes = entity_type_classes + tuple(
filter(
lambda x: x is not None, # type: ignore
[entity_types.get(entity_type) for entity_type in extracted_node.labels], # type: ignore
)
)
for entity_type in entity_type_classes:
for field_name in entity_type.model_fields:
summary_context.get('attributes', []).append(field_name) # type: ignore
entity_attributes_model = pydantic.create_model( # type: ignore
'EntityAttributes',
__base__=entity_type_classes + (Summary,), # type: ignore
)
llm_response, node_attributes_response = await semaphore_gather(
llm_client.generate_response(
prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
),
llm_client.generate_response(
prompt_library.summarize_nodes.summarize_context(summary_context),
response_model=Summary,
response_model=entity_attributes_model,
),
)
extracted_node.summary = node_summary_response.get('summary', '')
extracted_node.summary = node_attributes_response.get('summary', '')
extracted_node.attributes.update(node_attributes_response)
is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)
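
The attribute-extraction call above builds its response model dynamically: pydantic.create_model combines the matched entity type classes with Summary, so a single LLM call returns both the summary and the typed attribute values. A standalone sketch of that technique, using an illustrative Person type:

import pydantic
from pydantic import BaseModel, Field


class Summary(BaseModel):
    summary: str = Field(..., description='Summary containing the important information about the entity')


class Person(BaseModel):
    first_name: str | None = Field(..., description='First name')
    occupation: str | None = Field(..., description="The person's work occupation")


# Mirror the create_model call above: matched entity types plus Summary as bases.
EntityAttributes = pydantic.create_model('EntityAttributes', __base__=(Person, Summary))

print(sorted(EntityAttributes.model_fields))
# ['first_name', 'occupation', 'summary']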

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.6.1"
version = "0.7.0"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",