From 29a071b2b86d3716dfade3ce44311600e7ef23b1 Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Thu, 13 Feb 2025 12:17:52 -0500
Subject: [PATCH] Custom ontology (#262)
* ontology
* extract and save node labels
* extract entity type properties
* neo4j upgrade needed
* add entity types
* update typing
* update types
* updates
* Update graphiti_core/utils/maintenance/node_operations.py
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* fix warning
* mypy updates
* update properties
* mypy ignore
* mypy types
* bump version
---------
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---
.../msc_eval.py | 142 ------------------
.../msc_runner.py | 89 -----------
.../parse_msc_messages.py | 85 -----------
examples/podcast/podcast_runner.py | 8 +
graphiti_core/graphiti.py | 6 +-
graphiti_core/models/nodes/node_db_queries.py | 6 +-
graphiti_core/nodes.py | 37 +++--
graphiti_core/prompts/extract_nodes.py | 44 +++++-
graphiti_core/prompts/summarize_nodes.py | 18 ++-
graphiti_core/search/search_utils.py | 20 ++-
graphiti_core/utils/bulk_utils.py | 18 ++-
.../utils/maintenance/node_operations.py | 64 +++++++-
pyproject.toml | 2 +-
13 files changed, 189 insertions(+), 350 deletions(-)
delete mode 100644 examples/multi_session_conversation_memory/msc_eval.py
delete mode 100644 examples/multi_session_conversation_memory/msc_runner.py
delete mode 100644 examples/multi_session_conversation_memory/parse_msc_messages.py
diff --git a/examples/multi_session_conversation_memory/msc_eval.py b/examples/multi_session_conversation_memory/msc_eval.py
deleted file mode 100644
index db61482b..00000000
--- a/examples/multi_session_conversation_memory/msc_eval.py
+++ /dev/null
@@ -1,142 +0,0 @@
-"""
-Copyright 2024, Zep Software, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-import asyncio
-import csv
-import logging
-import os
-import sys
-from time import time
-
-from dotenv import load_dotenv
-
-from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
-from graphiti_core import Graphiti
-from graphiti_core.helpers import semaphore_gather
-from graphiti_core.prompts import prompt_library
-from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
-
-load_dotenv()
-
-neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
-neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
-neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
-
-
-def setup_logging():
- # Create a logger
- logger = logging.getLogger()
- logger.setLevel(logging.INFO) # Set the logging level to INFO
-
- # Create console handler and set level to INFO
- console_handler = logging.StreamHandler(sys.stdout)
- console_handler.setLevel(logging.INFO)
-
- # Create formatter
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-
- # Add formatter to console handler
- console_handler.setFormatter(formatter)
-
- # Add console handler to logger
- logger.addHandler(console_handler)
-
- return logger
-
-
-async def evaluate_qa(graphiti: Graphiti, group_id: str, query: str, answer: str):
- search_start = time()
- results = await graphiti._search(
- query,
- COMBINED_HYBRID_SEARCH_RRF,
- group_ids=[str(group_id)],
- )
- search_end = time()
- search_duration = search_end - search_start
-
- facts = [edge.fact for edge in results.edges]
- entity_summaries = [node.name + ': ' + node.summary for node in results.nodes]
- context = {
- 'facts': facts,
- 'entity_summaries': entity_summaries,
- 'query': 'Bob: ' + query,
- }
-
- llm_response = await graphiti.llm_client.generate_response(
- prompt_library.eval.qa_prompt(context)
- )
- response = llm_response.get('ANSWER', '')
-
- eval_context = {
- 'query': 'Bob: ' + query,
- 'answer': 'Alice: ' + answer,
- 'response': 'Alice: ' + response,
- }
-
- eval_llm_response = await graphiti.llm_client.generate_response(
- prompt_library.eval.eval_prompt(eval_context)
- )
- eval_response = 1 if eval_llm_response.get('is_correct', False) else 0
-
- return {
- 'Group id': group_id,
- 'Question': query,
- 'Answer': answer,
- 'Response': response,
- 'Score': eval_response,
- 'Search Duration (ms)': search_duration * 1000,
- }
-
-
-async def main():
- setup_logging()
- graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
-
- fields = [
- 'Group id',
- 'Question',
- 'Answer',
- 'Response',
- 'Score',
- 'Search Duration (ms)',
- ]
- with open('../data/msc_eval.csv', 'w', newline='') as file:
- writer = csv.DictWriter(file, fieldnames=fields)
- writer.writeheader()
-
- qa = conversation_q_and_a()[0:500]
- i = 0
- while i < 500:
- qa_chunk = qa[i : i + 20]
- group_ids = range(len(qa))[i : i + 20]
- results = list(
- await semaphore_gather(
- *[
- evaluate_qa(graphiti, str(group_id), query, answer)
- for group_id, (query, answer) in zip(group_ids, qa_chunk)
- ]
- )
- )
-
- with open('../data/msc_eval.csv', 'a', newline='') as file:
- writer = csv.DictWriter(file, fieldnames=fields)
- writer.writerows(results)
- i += 20
-
- await graphiti.close()
-
-
-asyncio.run(main())
diff --git a/examples/multi_session_conversation_memory/msc_runner.py b/examples/multi_session_conversation_memory/msc_runner.py
deleted file mode 100644
index 2cef9c58..00000000
--- a/examples/multi_session_conversation_memory/msc_runner.py
+++ /dev/null
@@ -1,89 +0,0 @@
-"""
-Copyright 2024, Zep Software, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-import asyncio
-import logging
-import os
-import sys
-
-from dotenv import load_dotenv
-
-from examples.multi_session_conversation_memory.parse_msc_messages import (
- ParsedMscMessage,
- parse_msc_messages,
-)
-from graphiti_core import Graphiti
-from graphiti_core.helpers import semaphore_gather
-
-load_dotenv()
-
-neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
-neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
-neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
-
-
-def setup_logging():
- # Create a logger
- logger = logging.getLogger()
- logger.setLevel(logging.INFO) # Set the logging level to INFO
-
- # Create console handler and set level to INFO
- console_handler = logging.StreamHandler(sys.stdout)
- console_handler.setLevel(logging.INFO)
-
- # Create formatter
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-
- # Add formatter to console handler
- console_handler.setFormatter(formatter)
-
- # Add console handler to logger
- logger.addHandler(console_handler)
-
- return logger
-
-
-async def add_conversation(graphiti: Graphiti, group_id: str, messages: list[ParsedMscMessage]):
- for i, message in enumerate(messages):
- await graphiti.add_episode(
- name=f'Message {group_id + "-" + str(i)}',
- episode_body=f'{message.speaker_name}: {message.content}',
- reference_time=message.actual_timestamp,
- source_description='Multi-Session Conversation',
- group_id=group_id,
- )
-
-
-async def main():
- setup_logging()
- graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
- msc_messages = parse_msc_messages()
- i = 0
- while i < len(msc_messages):
- msc_message_slice = msc_messages[i : i + 10]
- group_ids = range(len(msc_messages))[i : i + 10]
-
- await semaphore_gather(
- *[
- add_conversation(graphiti, str(group_id), messages)
- for group_id, messages in zip(group_ids, msc_message_slice)
- ]
- )
-
- i += 10
-
-
-asyncio.run(main())
diff --git a/examples/multi_session_conversation_memory/parse_msc_messages.py b/examples/multi_session_conversation_memory/parse_msc_messages.py
deleted file mode 100644
index 4c5bd219..00000000
--- a/examples/multi_session_conversation_memory/parse_msc_messages.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""
-Copyright 2024, Zep Software, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-import json
-from datetime import datetime, timezone
-
-from pydantic import BaseModel
-
-
-class ParsedMscMessage(BaseModel):
- speaker_name: str
- actual_timestamp: datetime
- content: str
- group_id: str
-
-
-def parse_msc_messages() -> list[list[ParsedMscMessage]]:
- msc_messages: list[list[ParsedMscMessage]] = []
- speakers = ['Alice', 'Bob']
-
- with open('../data/msc.jsonl') as file:
- data = [json.loads(line) for line in file]
- for i, conversation in enumerate(data):
- messages: list[ParsedMscMessage] = []
- for previous_dialog in conversation['previous_dialogs']:
- dialog = previous_dialog['dialog']
- speaker_idx = 0
-
- for utterance in dialog:
- content = utterance['text']
- messages.append(
- ParsedMscMessage(
- speaker_name=speakers[speaker_idx],
- content=content,
- actual_timestamp=datetime.now(timezone.utc),
- group_id=str(i),
- )
- )
- speaker_idx += 1
- speaker_idx %= 2
-
- dialog = conversation['dialog']
- speaker_idx = 0
- for utterance in dialog:
- content = utterance['text']
- messages.append(
- ParsedMscMessage(
- speaker_name=speakers[speaker_idx],
- content=content,
- actual_timestamp=datetime.now(timezone.utc),
- group_id=str(i),
- )
- )
- speaker_idx += 1
- speaker_idx %= 2
-
- msc_messages.append(messages)
-
- return msc_messages
-
-
-def conversation_q_and_a() -> list[tuple[str, str]]:
- with open('../data/msc.jsonl') as file:
- data = [json.loads(line) for line in file]
-
- qa: list[tuple[str, str]] = []
- for conversation in data:
- query = conversation['self_instruct']['B']
- answer = conversation['self_instruct']['A']
-
- qa.append((query, answer))
- return qa
diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index 0ee01eb2..d5511375 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -20,6 +20,7 @@ import os
import sys
from dotenv import load_dotenv
+from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti
@@ -53,6 +54,12 @@ def setup_logging():
return logger
+class Person(BaseModel):  # example custom entity type; passed as entity_types={'Person': Person} in main()
+ first_name: str | None = Field(..., description='First name')  # `...` = key is required, value may be None
+ last_name: str | None = Field(..., description='Last name')
+ occupation: str | None = Field(..., description="The person's work occupation")
+
+
async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
@@ -67,6 +74,7 @@ async def main():
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='podcast',
+ entity_types={'Person': Person},
)
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 81cf90fc..df12d98d 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -262,6 +262,7 @@ class Graphiti:
group_id: str = '',
uuid: str | None = None,
update_communities: bool = False,
+ entity_types: dict[str, BaseModel] | None = None,
) -> AddEpisodeResults:
"""
Process an episode and update the graph.
@@ -336,7 +337,9 @@ class Graphiti:
# Extract entities as nodes
- extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
+ extracted_nodes = await extract_nodes(
+ self.llm_client, episode, previous_episodes, entity_types
+ )
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
# Calculate Embeddings
@@ -362,6 +365,7 @@ class Graphiti:
existing_nodes_lists,
episode,
previous_episodes,
+ entity_types,
),
extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes, group_id
diff --git a/graphiti_core/models/nodes/node_db_queries.py b/graphiti_core/models/nodes/node_db_queries.py
index 9010532b..a34c5cb5 100644
--- a/graphiti_core/models/nodes/node_db_queries.py
+++ b/graphiti_core/models/nodes/node_db_queries.py
@@ -31,14 +31,16 @@ EPISODIC_NODE_SAVE_BULK = """
ENTITY_NODE_SAVE = """
MERGE (n:Entity {uuid: $uuid})
- SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
+ SET n:$($labels)
+ SET n = $entity_data
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid"""
ENTITY_NODE_SAVE_BULK = """
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
- SET n = {uuid: node.uuid, name: node.name, group_id: node.group_id, summary: node.summary, created_at: node.created_at}
+ SET n:$(node.labels)
+ SET n = node
WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
RETURN n.uuid AS uuid
"""
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index 6a490c8c..508a0d4a 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -255,6 +255,9 @@ class EpisodicNode(Node):
class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
+ attributes: dict[str, Any] = Field(
+ default={}, description='Additional attributes of the node. Dependent on node labels'
+ )
async def generate_name_embedding(self, embedder: EmbedderClient):
start = time()
@@ -266,14 +269,21 @@ class EntityNode(Node):
return self.name_embedding
async def save(self, driver: AsyncDriver):
+ entity_data: dict[str, Any] = {
+ 'uuid': self.uuid,
+ 'name': self.name,
+ 'name_embedding': self.name_embedding,
+ 'group_id': self.group_id,
+ 'summary': self.summary,
+ 'created_at': self.created_at,
+ }
+
+ entity_data.update(self.attributes or {})
+
result = await driver.execute_query(
ENTITY_NODE_SAVE,
- uuid=self.uuid,
- name=self.name,
- group_id=self.group_id,
- summary=self.summary,
- name_embedding=self.name_embedding,
- created_at=self.created_at,
+ labels=self.labels + ['Entity'],
+ entity_data=entity_data,
database_=DEFAULT_DATABASE,
)
@@ -292,7 +302,9 @@ class EntityNode(Node):
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
- n.summary AS summary
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
""",
uuid=uuid,
database_=DEFAULT_DATABASE,
@@ -317,7 +329,9 @@ class EntityNode(Node):
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
- n.summary AS summary
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
""",
uuids=uuids,
database_=DEFAULT_DATABASE,
@@ -351,7 +365,9 @@ class EntityNode(Node):
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
- n.summary AS summary
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
ORDER BY n.uuid DESC
"""
+ limit_query,
@@ -503,9 +519,10 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
name=record['name'],
group_id=record['group_id'],
name_embedding=record['name_embedding'],
- labels=['Entity'],
+ labels=record['labels'],
created_at=record['created_at'].to_native(),
summary=record['summary'],
+ attributes=record['attributes'],
)
diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py
index 49e2036b..ebbaec87 100644
--- a/graphiti_core/prompts/extract_nodes.py
+++ b/graphiti_core/prompts/extract_nodes.py
@@ -30,11 +30,19 @@ class MissedEntities(BaseModel):
missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
+class EntityClassification(BaseModel):
+    # Typed as str on purpose: extract_nodes parses the value with ast.literal_eval into a dict.
+    entity_classification: str = Field(
+        ..., description='Python dict literal mapping each entity name to its entity type, or None'
+    )
+
+
class Prompt(Protocol):
extract_message: PromptVersion
extract_json: PromptVersion
extract_text: PromptVersion
reflexion: PromptVersion
+ classify_nodes: PromptVersion
class Versions(TypedDict):
@@ -42,6 +50,7 @@ class Versions(TypedDict):
extract_json: PromptFunction
extract_text: PromptFunction
reflexion: PromptFunction
+ classify_nodes: PromptFunction
def extract_message(context: dict[str, Any]) -> list[Message]:
@@ -66,6 +75,7 @@ Guidelines:
4. DO NOT create nodes for temporal information like dates, times or years (these will be added to edges later).
5. Be as explicit as possible in your node names, using full names.
6. DO NOT extract entities mentioned only in PREVIOUS MESSAGES, those messages are only to provide context.
+7. Extract preferences as their own nodes
"""
return [
Message(role='system', content=sys_prompt),
@@ -109,7 +119,7 @@ def extract_text(context: dict[str, Any]) -> list[Message]:
{context['custom_prompt']}
-Given the following text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:
+Given the above text, extract entity nodes from the TEXT that are explicitly or implicitly mentioned:
Guidelines:
1. Extract significant entities, concepts, or actors mentioned in the conversation.
@@ -147,9 +157,41 @@ extracted.
]
+def classify_nodes(context: dict[str, Any]) -> list[Message]:
+    """Build the prompt that asks the LLM to assign one entity type to each extracted entity."""
+    sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted"""
+    user_prompt = f"""
+    <PREVIOUS MESSAGES>
+    {json.dumps(list(context['previous_episodes']), indent=2)}
+    </PREVIOUS MESSAGES>
+    <CURRENT MESSAGE>
+    {context['episode_content']}
+    </CURRENT MESSAGE>
+
+    <EXTRACTED ENTITIES>
+    {context['extracted_entities']}
+    </EXTRACTED ENTITIES>
+
+    <ENTITY TYPES>
+    {list(context['entity_types'])}
+    </ENTITY TYPES>
+
+    Given the above conversation, extracted entities, and provided entity types, classify the extracted entities.
+
+    Guidelines:
+    1. Each entity must have exactly one type
+    2. If none of the provided entity types accurately classify an extracted node, the type should be set to None
+"""
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
+
+
versions: Versions = {
'extract_message': extract_message,
'extract_json': extract_json,
'extract_text': extract_text,
'reflexion': reflexion,
+ 'classify_nodes': classify_nodes,
}
diff --git a/graphiti_core/prompts/summarize_nodes.py b/graphiti_core/prompts/summarize_nodes.py
index e00e1bab..0a880a82 100644
--- a/graphiti_core/prompts/summarize_nodes.py
+++ b/graphiti_core/prompts/summarize_nodes.py
@@ -24,7 +24,8 @@ from .models import Message, PromptFunction, PromptVersion
class Summary(BaseModel):
summary: str = Field(
- ..., description='Summary containing the important information from both summaries'
+ ...,
+ description='Summary containing the important information about the entity. Under 500 words',
)
@@ -68,7 +69,7 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
- content='You are a helpful assistant that combines summaries with new conversation context.',
+ content='You are a helpful assistant that extracts entity properties from the provided text.',
),
Message(
role='user',
@@ -79,18 +80,23 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
{json.dumps(context['episode_content'], indent=2)}
- Given the above MESSAGES and the following ENTITY name and ENTITY CONTEXT, create a summary for the ENTITY. Your summary must only use
- information from the provided MESSAGES and from the ENTITY CONTEXT. Your summary should also only contain information relevant to the
- provided ENTITY.
+ Given the above MESSAGES and the following ENTITY name, create a summary for the ENTITY. Your summary must only use
+ information from the provided MESSAGES. Your summary should also only contain information relevant to the
+ provided ENTITY. Summaries must be under 500 words.
- Summaries must be under 500 words.
+ In addition, extract any values for the provided entity properties based on their descriptions.
{context['node_name']}
+
{context['node_summary']}
+
+
+ {json.dumps(context['attributes'], indent=2)}
+
""",
),
]
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index c4d44fd1..ef9b6eb2 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -97,7 +97,9 @@ async def get_mentioned_nodes(
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
- n.summary AS summary
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
""",
uuids=episode_uuids,
database_=DEFAULT_DATABASE,
@@ -223,8 +225,8 @@ async def edge_similarity_search(
query: LiteralString = (
"""
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
- """
+ MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+ """
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@@ -341,7 +343,9 @@ async def node_fulltext_search(
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
- n.summary AS summary
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
ORDER BY score DESC
LIMIT $limit
""",
@@ -390,7 +394,9 @@ async def node_similarity_search(
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
- n.summary AS summary
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
ORDER BY score DESC
LIMIT $limit
""",
@@ -427,7 +433,9 @@ async def node_bfs_search(
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
- n.summary AS summary
+ n.summary AS summary,
+ labels(n) AS labels,
+ properties(n) AS attributes
LIMIT $limit
""",
bfs_origin_node_uuids=bfs_origin_node_uuids,
diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py
index 80f66029..b2340d63 100644
--- a/graphiti_core/utils/bulk_utils.py
+++ b/graphiti_core/utils/bulk_utils.py
@@ -23,6 +23,7 @@ from math import ceil
from neo4j import AsyncDriver, AsyncManagedTransaction
from numpy import dot, sqrt
from pydantic import BaseModel
+from typing_extensions import Any
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import semaphore_gather
@@ -109,8 +110,23 @@ async def add_nodes_and_edges_bulk_tx(
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
+ nodes: list[dict[str, Any]] = []
+ for node in entity_nodes:
+ entity_data: dict[str, Any] = {
+ 'uuid': node.uuid,
+ 'name': node.name,
+ 'name_embedding': node.name_embedding,
+ 'group_id': node.group_id,
+ 'summary': node.summary,
+ 'created_at': node.created_at,
+ }
+
+ entity_data.update(node.attributes or {})
+ entity_data['labels'] = list(set(node.labels + ['Entity']))
+ nodes.append(entity_data)
+
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
- await tx.run(ENTITY_NODE_SAVE_BULK, nodes=[dict(entity) for entity in entity_nodes])
+ await tx.run(ENTITY_NODE_SAVE_BULK, nodes=nodes)
await tx.run(EPISODIC_EDGE_SAVE_BULK, episodic_edges=[dict(edge) for edge in episodic_edges])
await tx.run(ENTITY_EDGE_SAVE_BULK, entity_edges=[dict(edge) for edge in entity_edges])
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 31e916b4..3d4f6bf3 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -14,15 +14,19 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+import ast
import logging
from time import time
+import pydantic
+from pydantic import BaseModel
+
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
-from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
+from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
from graphiti_core.prompts.summarize_nodes import Summary
from graphiti_core.utils.datetime_utils import utc_now
@@ -114,6 +118,7 @@ async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
+ entity_types: dict[str, BaseModel] | None = None,
) -> list[EntityNode]:
start = time()
extracted_node_names: list[str] = []
@@ -144,15 +149,35 @@ async def extract_nodes(
for entity in missing_entities:
custom_prompt += f'\n{entity},'
+ node_classification_context = {
+ 'episode_content': episode.content,
+ 'previous_episodes': [ep.content for ep in previous_episodes],
+ 'extracted_entities': extracted_node_names,
+ 'entity_types': entity_types.keys() if entity_types is not None else [],
+ }
+
+ node_classifications: dict[str, str | None] = {}
+
+ if entity_types is not None:
+ llm_response = await llm_client.generate_response(
+ prompt_library.extract_nodes.classify_nodes(node_classification_context),
+ response_model=EntityClassification,
+ )
+ response_string = llm_response.get('entity_classification', '{}')
+ node_classifications.update(ast.literal_eval(response_string))
+
end = time()
logger.debug(f'Extracted new nodes: {extracted_node_names} in {(end - start) * 1000} ms')
# Convert the extracted data into EntityNode objects
new_nodes = []
for name in extracted_node_names:
+ entity_type = node_classifications.get(name)
+ labels = ['Entity'] if entity_type is None else ['Entity', entity_type]
+
new_node = EntityNode(
name=name,
group_id=episode.group_id,
- labels=['Entity'],
+ labels=labels,
summary='',
created_at=utc_now(),
)
@@ -218,6 +243,7 @@ async def resolve_extracted_nodes(
existing_nodes_lists: list[list[EntityNode]],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
+ entity_types: dict[str, BaseModel] | None = None,
) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = []
@@ -225,7 +251,12 @@ async def resolve_extracted_nodes(
await semaphore_gather(
*[
resolve_extracted_node(
- llm_client, extracted_node, existing_nodes, episode, previous_episodes
+ llm_client,
+ extracted_node,
+ existing_nodes,
+ episode,
+ previous_episodes,
+ entity_types,
)
for extracted_node, existing_nodes in zip(extracted_nodes, existing_nodes_lists)
]
@@ -245,6 +276,7 @@ async def resolve_extracted_node(
existing_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
+ entity_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityNode, dict[str, str]]:
start = time()
@@ -273,19 +305,39 @@ async def resolve_extracted_node(
'previous_episodes': [ep.content for ep in previous_episodes]
if previous_episodes is not None
else [],
+ 'attributes': [],
}
- llm_response, node_summary_response = await semaphore_gather(
+ entity_type_classes: tuple[BaseModel, ...] = tuple()
+ if entity_types is not None: # type: ignore
+ entity_type_classes = entity_type_classes + tuple(
+ filter(
+ lambda x: x is not None, # type: ignore
+ [entity_types.get(entity_type) for entity_type in extracted_node.labels], # type: ignore
+ )
+ )
+
+ for entity_type in entity_type_classes:
+ for field_name in entity_type.model_fields:
+ summary_context.get('attributes', []).append(field_name) # type: ignore
+
+ entity_attributes_model = pydantic.create_model( # type: ignore
+ 'EntityAttributes',
+ __base__=entity_type_classes + (Summary,), # type: ignore
+ )
+
+ llm_response, node_attributes_response = await semaphore_gather(
llm_client.generate_response(
prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
),
llm_client.generate_response(
prompt_library.summarize_nodes.summarize_context(summary_context),
- response_model=Summary,
+ response_model=entity_attributes_model,
),
)
- extracted_node.summary = node_summary_response.get('summary', '')
+ extracted_node.summary = node_attributes_response.get('summary', '')
+ extracted_node.attributes.update(node_attributes_response)
is_duplicate: bool = llm_response.get('is_duplicate', False)
uuid: str | None = llm_response.get('uuid', None)
diff --git a/pyproject.toml b/pyproject.toml
index aa0a52f1..790a3a72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
-version = "0.6.1"
+version = "0.7.0"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk ",