chore: Fix Typing Issues (#27)

* typing.Any and friends

* message

* chore: Import Message model in llm_client

* fix: 💄 mypy errors

* clean up mypy stuff

* mypy

* format

---------

Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
Co-authored-by: prestonrasmussen <prasmuss15@gmail.com>
Daniel Chalef 2024-08-23 08:15:44 -07:00, committed by GitHub
parent 7152a211ae, commit 9cc9883e66
24 changed files with 134 additions and 587 deletions

View File

@@ -22,7 +22,7 @@ format:
# Lint code
lint:
$(RUFF) check
$(MYPY) . --show-column-numbers --show-error-codes --pretty
$(MYPY) ./core --show-column-numbers --show-error-codes --pretty
# Run tests
test:

View File

@@ -56,7 +56,7 @@ class Graphiti:
else:
self.llm_client = OpenAIClient(
LLMConfig(
api_key=os.getenv('OPENAI_API_KEY'),
api_key=os.getenv('OPENAI_API_KEY', default=''),
model='gpt-4o-mini',
base_url='https://api.openai.com/v1',
)
@@ -72,28 +72,16 @@ class Graphiti:
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
sources: list[str] | None = 'messages',
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
return await retrieve_episodes(self.driver, reference_time, last_n, sources)
# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
return await retrieve_episodes(self.driver, reference_time, last_n)
async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime | None = None,
episode_type: str | None = 'string', # TODO: this field isn't used yet?
reference_time: datetime,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
@@ -104,7 +92,7 @@ class Graphiti:
nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.client.embeddings
embedder = self.llm_client.get_embedder()
now = datetime.now()
previous_episodes = await self.retrieve_episodes(reference_time)
@@ -234,7 +222,7 @@ class Graphiti:
):
try:
start = time()
embedder = self.llm_client.client.embeddings
embedder = self.llm_client.get_embedder()
now = datetime.now()
episodes = [
@@ -276,14 +264,22 @@ class Graphiti:
await asyncio.gather(*[node.save(self.driver) for node in nodes])
# re-map edge pointers so that they don't point to discard dupe nodes
extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map)
episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map)
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
extracted_edges, uuid_map
)
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
episodic_edges, uuid_map
)
# save episodic edges to KG
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
)
# Dedupe extracted edges
edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges)
edges = await dedupe_edges_bulk(
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
)
logger.info(f'extracted edge length: {len(edges)}')
# invalidate edges
@@ -302,12 +298,12 @@ class Graphiti:
edges = (
await hybrid_search(
self.driver,
self.llm_client.client.embeddings,
self.llm_client.get_embedder(),
query,
datetime.now(),
search_config,
)
)['edges']
).edges
facts = [edge.fact for edge in edges]
@@ -315,5 +311,5 @@ class Graphiti:
async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
return await hybrid_search(
self.driver, self.llm_client.client.embeddings, query, timestamp, config
self.driver, self.llm_client.get_embedder(), query, timestamp, config
)
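
Net effect of the graphiti.py changes: reference_time is now a required argument to add_episode, the unused sources and episode_type parameters are gone, and the embedder is reached through get_embedder() instead of the private client.embeddings attribute. A minimal usage sketch (hypothetical; how the Graphiti instance is constructed is not part of this diff):

from datetime import datetime

from core.graphiti import Graphiti

async def ingest(graphiti: Graphiti) -> None:
    # reference_time is now required; sources/episode_type no longer exist
    await graphiti.add_episode(
        name='episode-1',
        episode_body='Alice met Bob at the conference.',
        source_description='chat transcript',
        reference_time=datetime.now(),
    )
    # retrieve_episodes likewise lost its sources filter
    episodes = await graphiti.retrieve_episodes(datetime.now())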

View File

@@ -1,5 +1,7 @@
import typing
from abc import ABC, abstractmethod
from ..prompts.models import Message
from .config import LLMConfig
@@ -9,5 +11,9 @@ class LLMClient(ABC):
pass
@abstractmethod
async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
def get_embedder(self) -> typing.Any:
pass
@abstractmethod
async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
pass
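
With get_embedder added to the abstract base and generate_response retyped to take list[Message], a conforming client looks roughly like this (a stub sketch; EchoClient and its None embedder are purely illustrative):

import typing

from core.llm_client import LLMClient
from core.prompts.models import Message

class EchoClient(LLMClient):
    def __init__(self, config):
        self.config = config

    def get_embedder(self) -> typing.Any:
        return None  # a real client returns its embeddings handle here

    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
        # echo the last message back in a JSON-shaped dict
        return {'echo': messages[-1].content if messages else ''}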

View File

@@ -1,8 +1,11 @@
import json
import logging
import typing
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
@@ -14,16 +17,26 @@ class OpenAIClient(LLMClient):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
self.model = config.model
async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
def get_embedder(self) -> typing.Any:
return self.client.embeddings
async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':
openai_messages.append({'role': 'system', 'content': m.content})
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=messages,
messages=openai_messages,
temperature=0.1,
max_tokens=3000,
response_format={'type': 'json_object'},
)
return json.loads(response.choices[0].message.content)
result = response.choices[0].message.content or ''
return json.loads(result)
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise
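
The role-mapping loop hands the SDK properly typed ChatCompletionMessageParam dicts, and `content or ''` narrows str | None to str for json.loads (an empty completion will still raise JSONDecodeError; the coercion only satisfies mypy). A usage sketch with placeholder credentials and assumed import paths:

import asyncio

from core.llm_client import OpenAIClient
from core.llm_client.config import LLMConfig
from core.prompts.models import Message

async def main() -> None:
    client = OpenAIClient(
        LLMConfig(api_key='sk-placeholder', model='gpt-4o-mini',
                  base_url='https://api.openai.com/v1')
    )
    result = await client.generate_response([
        Message(role='system', content='Respond only with a JSON object.'),
        Message(role='user', content='Return {"greeting": "hello"}.'),
    ])
    print(result)  # parsed dict from the JSON-mode response

asyncio.run(main())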

View File

@@ -1,5 +1,5 @@
import json
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict
from .models import Message, PromptFunction, PromptVersion
@@ -7,6 +7,7 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
edge_list: PromptVersion
class Versions(TypedDict):
@@ -15,7 +16,7 @@ class Versions(TypedDict):
edge_list: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
@@ -55,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
]
def v2(context: dict[str, any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
@@ -97,7 +98,7 @@ def v2(context: dict[str, any]) -> list[Message]:
]
def edge_list(context: dict[str, any]) -> list[Message]:
def edge_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
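
The recurring dict[str, any] → dict[str, Any] change in these prompt modules fixes a genuine annotation bug: lower-case any is the builtin function, which mypy rejects as a type. In isolation:

from typing import Any

def ok(context: dict[str, Any]) -> None: ...    # valid annotation

# def bad(context: dict[str, any]) -> None: ...
# mypy: error: Function "builtins.any" is not valid as a type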

View File

@@ -1,5 +1,5 @@
import json
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict
from .models import Message, PromptFunction, PromptVersion
@@ -16,7 +16,7 @@ class Versions(TypedDict):
node_list: PromptVersion
def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
@@ -56,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
]
def v2(context: dict[str, any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
@@ -96,7 +96,7 @@ def v2(context: dict[str, any]) -> list[Message]:
]
def node_list(context: dict[str, any]) -> list[Message]:
def node_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',

View File

@@ -1,5 +1,5 @@
import json
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict
from .models import Message, PromptFunction, PromptVersion
@@ -14,7 +14,7 @@ class Versions(TypedDict):
v2: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
@@ -70,7 +70,7 @@ def v1(context: dict[str, any]) -> list[Message]:
]
def v2(context: dict[str, any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',

View File

@@ -1,5 +1,5 @@
import json
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict
from .models import Message, PromptFunction, PromptVersion
@@ -16,7 +16,7 @@ class Versions(TypedDict):
v3: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
@@ -64,7 +64,7 @@ def v1(context: dict[str, any]) -> list[Message]:
]
def v2(context: dict[str, any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
@@ -105,7 +105,7 @@ def v2(context: dict[str, any]) -> list[Message]:
]
def v3(context: dict[str, any]) -> list[Message]:
def v3(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
user_prompt = f"""

View File

@@ -1,4 +1,4 @@
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict
from .models import Message, PromptFunction, PromptVersion
@@ -11,7 +11,7 @@ class Versions(TypedDict):
v1: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',

View File

@@ -1,4 +1,4 @@
from typing import Protocol, TypedDict
from typing import Any, Protocol, TypedDict
from .dedupe_edges import (
Prompt as DedupeEdgesPrompt,
@@ -68,7 +68,7 @@ class VersionWrapper:
def __init__(self, func: PromptFunction):
self.func = func
def __call__(self, context: dict[str, any]) -> list[Message]:
def __call__(self, context: dict[str, Any]) -> list[Message]:
return self.func(context)
@@ -81,7 +81,7 @@ class PromptTypeWrapper:
class PromptLibraryWrapper:
def __init__(self, library: PromptLibraryImpl):
for prompt_type, versions in library.items():
setattr(self, prompt_type, PromptTypeWrapper(versions))
setattr(self, prompt_type, PromptTypeWrapper(versions)) # type: ignore[arg-type]
PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
@@ -91,5 +91,4 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
'dedupe_edges': dedupe_edges_versions,
'invalidate_edges': invalidate_edges_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]

View File

@@ -1,4 +1,4 @@
from typing import Callable, Protocol
from typing import Any, Callable, Protocol
from pydantic import BaseModel
@@ -9,7 +9,7 @@ class Message(BaseModel):
class PromptVersion(Protocol):
def __call__(self, context: dict[str, any]) -> list[Message]: ...
def __call__(self, context: dict[str, Any]) -> list[Message]: ...
PromptFunction = Callable[[dict[str, any]], list[Message]]
PromptFunction = Callable[[dict[str, Any]], list[Message]]
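
PromptVersion stays a Protocol and PromptFunction a Callable alias; with Any in place, a toy prompt function type-checks against both (illustrative, not a prompt from the library):

from typing import Any

from core.prompts.models import Message, PromptFunction

def greet(context: dict[str, Any]) -> list[Message]:
    return [Message(role='user', content=f"Hello, {context['name']}!")]

fn: PromptFunction = greet  # matches Callable[[dict[str, Any]], list[Message]]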

View File

@@ -5,9 +5,9 @@ from time import time
from neo4j import AsyncDriver
from pydantic import BaseModel
from core.edges import Edge
from core.edges import EntityEdge
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
from core.nodes import EntityNode, EpisodicNode
from core.search.search_utils import (
edge_fulltext_search,
edge_similarity_search,
@@ -28,9 +28,15 @@ class SearchConfig(BaseModel):
reranker: str = 'rrf'
class SearchResults(BaseModel):
episodes: list[EpisodicNode]
nodes: list[EntityNode]
edges: list[EntityEdge]
async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, [Node | Edge]]:
) -> SearchResults:
start = time()
episodes = []
@@ -86,11 +92,7 @@ async def hybrid_search(
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)
context = {
'episodes': episodes,
'nodes': nodes,
'edges': edges,
}
context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)
end = time()
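
Returning a SearchResults model instead of an untyped context dict is what lets graphiti.py switch from ['edges'] to .edges above. A consumption sketch (driver and embedder elided; module path assumed; assumes SearchConfig's remaining fields have defaults, as reranker: str = 'rrf' suggests):

from datetime import datetime

from core.search.search import SearchConfig, SearchResults, hybrid_search

async def fetch_facts(driver, embedder, query: str) -> list[str]:
    results: SearchResults = await hybrid_search(
        driver, embedder, query, datetime.now(), SearchConfig()
    )
    return [edge.fact for edge in results.edges]  # attribute access, checked by mypy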

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import typing
from collections import defaultdict
from datetime import datetime
from time import time
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
return neo_date.to_native() if neo_date else None
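
The old annotation claimed neo4j_time.Date, but the temporal properties this helper receives are Neo4j DateTime values, whose to_native() yields a datetime.datetime:

from neo4j import time as neo4j_time

dt = neo4j_time.DateTime(2024, 8, 23, 8, 15, 44)
print(dt.to_native())  # datetime.datetime(2024, 8, 23, 8, 15, 44)
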
@@ -41,7 +42,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
uuid=record['uuid'],
name=record['name'],
labels=['Entity'],
created_at=datetime.now(),
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
@@ -74,7 +75,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
node_ids=node_ids,
)
context = {}
context: dict[str, typing.Any] = {}
for record in records:
n_uuid = record['source_node_uuid']
@@ -173,7 +174,7 @@ async def entity_similarity_search(
uuid=record['uuid'],
name=record['name'],
labels=['Entity'],
created_at=datetime.now(),
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
@@ -208,7 +209,7 @@ async def entity_fulltext_search(
uuid=record['uuid'],
name=record['name'],
labels=['Entity'],
created_at=datetime.now(),
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
)
@@ -277,7 +278,11 @@ async def get_relevant_nodes(
results = await asyncio.gather(
*[entity_fulltext_search(node.name, driver) for node in nodes],
*[entity_similarity_search(node.name_embedding, driver) for node in nodes],
*[
entity_similarity_search(node.name_embedding, driver)
for node in nodes
if node.name_embedding is not None
],
)
for result in results:
@@ -303,7 +308,11 @@ async def get_relevant_edges(
relevant_edge_uuids = set()
results = await asyncio.gather(
*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
*[
edge_similarity_search(edge.fact_embedding, driver)
for edge in edges
if edge.fact_embedding is not None
],
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
)
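
The comprehension guards in get_relevant_nodes and get_relevant_edges are the standard mypy narrowing idiom: filtering None inside the comprehension proves the embedding is a list[float] where the similarity searches require one. The same pattern in miniature:

def norms(embeddings: list[list[float] | None]) -> list[float]:
    # the `is not None` guard narrows each e to list[float]
    return [sum(x * x for x in e) ** 0.5 for e in embeddings if e is not None]

print(norms([[3.0, 4.0], None]))  # [5.0]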

View File

@@ -1,15 +1,15 @@
from .maintenance import (
build_episodic_edges,
clear_data,
extract_new_edges,
extract_new_nodes,
extract_edges,
extract_nodes,
retrieve_episodes,
)
__all__ = [
'extract_new_edges',
'extract_edges',
'build_episodic_edges',
'extract_new_nodes',
'extract_nodes',
'clear_data',
'retrieve_episodes',
]

View File

@@ -1,4 +1,5 @@
import asyncio
import typing
from datetime import datetime
from neo4j import AsyncDriver
@@ -121,8 +122,8 @@ async def dedupe_edges_bulk(
def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map = {}
name_map = {}
uuid_map: dict[str, str] = {}
name_map: dict[str, EntityNode] = {}
for node in nodes:
if node.name in name_map:
uuid_map[node.uuid] = name_map[node.name].uuid
@@ -182,7 +183,10 @@ def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
return compressed_map
def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
E = typing.TypeVar('E', bound=Edge)
def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
for edge in edges:
source_uuid = edge.source_node_uuid
target_uuid = edge.target_node_uuid
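
Binding the edge type with a TypeVar keeps resolve_edge_pointers generic: a list[EntityEdge] comes back as list[EntityEdge] rather than decaying to the Edge base. The pattern in isolation (toy classes, not the real models):

import typing

class Edge: ...
class EntityEdge(Edge): ...

E = typing.TypeVar('E', bound=Edge)

def passthrough(edges: list[E]) -> list[E]:
    return edges

entity_edges: list[EntityEdge] = [EntityEdge()]
same: list[EntityEdge] = passthrough(entity_edges)  # element type preserved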

View File

@@ -1,15 +1,15 @@
from .edge_operations import build_episodic_edges, extract_new_edges
from .edge_operations import build_episodic_edges, extract_edges
from .graph_data_operations import (
clear_data,
retrieve_episodes,
)
from .node_operations import extract_new_nodes
from .node_operations import extract_nodes
from .temporal_operations import invalidate_edges
__all__ = [
'extract_new_edges',
'extract_edges',
'build_episodic_edges',
'extract_new_nodes',
'extract_nodes',
'clear_data',
'retrieve_episodes',
'invalidate_edges',

View File

@@ -1,4 +1,3 @@
import json
import logging
from datetime import datetime
from time import time
@@ -8,7 +7,6 @@ from core.edges import EntityEdge, EpisodicEdge
from core.llm_client import LLMClient
from core.nodes import EntityNode, EpisodicNode
from core.prompts import prompt_library
from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet
logger = logging.getLogger(__name__)
@@ -31,103 +29,6 @@ def build_episodic_edges(
return edges
async def extract_new_edges(
llm_client: LLMClient,
episode: EpisodicNode,
new_nodes: list[EntityNode],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> tuple[list[EntityEdge], list[EntityNode]]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
'relevant_schema': json.dumps(relevant_schema, indent=2),
'new_nodes': [{'name': node.name, 'summary': node.summary} for node in new_nodes],
'previous_episodes': [
{
'content': ep.content,
'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(prompt_library.extract_edges.v1(context))
new_edges_data = llm_response.get('new_edges', [])
logger.info(f'Extracted new edges: {new_edges_data}')
# Convert the extracted data into EntityEdge objects
new_edges = []
for edge_data in new_edges_data:
source_node = next(
(node for node in new_nodes if node.name == edge_data['source_node']),
None,
)
target_node = next(
(node for node in new_nodes if node.name == edge_data['target_node']),
None,
)
# If source or target is not in new_nodes, check if it's an existing node
if source_node is None and edge_data['source_node'] in relevant_schema['nodes']:
existing_node_data = relevant_schema['nodes'][edge_data['source_node']]
source_node = EntityNode(
uuid=existing_node_data['uuid'],
name=edge_data['source_node'],
labels=[existing_node_data['label']],
summary='',
created_at=datetime.now(),
)
if target_node is None and edge_data['target_node'] in relevant_schema['nodes']:
existing_node_data = relevant_schema['nodes'][edge_data['target_node']]
target_node = EntityNode(
uuid=existing_node_data['uuid'],
name=edge_data['target_node'],
labels=[existing_node_data['label']],
summary='',
created_at=datetime.now(),
)
if (
source_node
and target_node
and not (
source_node.name.startswith('Message') or target_node.name.startswith('Message')
)
):
valid_at = (
datetime.fromisoformat(edge_data['valid_at'])
if edge_data['valid_at']
else episode.valid_at or datetime.now()
)
invalid_at = (
datetime.fromisoformat(edge_data['invalid_at']) if edge_data['invalid_at'] else None
)
new_edge = EntityEdge(
source_node=source_node,
target_node=target_node,
name=edge_data['relation_type'],
fact=edge_data['fact'],
episodes=[episode.uuid],
created_at=datetime.now(),
valid_at=valid_at,
invalid_at=invalid_at,
)
new_edges.append(new_edge)
logger.info(
f'Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})'
)
affected_nodes = set()
for edge in new_edges:
affected_nodes.add(edge.source_node)
affected_nodes.add(edge.target_node)
return new_edges, list(affected_nodes)
async def extract_edges(
llm_client: LLMClient,
episode: EpisodicNode,
@@ -186,45 +87,6 @@ def create_edge_identifier(
return f'{source_node.name}-{edge.name}-{target_node.name}'
async def dedupe_extracted_edges_v2(
llm_client: LLMClient,
extracted_edges: list[NodeEdgeNodeTriplet],
existing_edges: list[NodeEdgeNodeTriplet],
) -> list[NodeEdgeNodeTriplet]:
# Create edge map
edge_map = {}
for n1, edge, n2 in existing_edges:
edge_map[create_edge_identifier(n1, edge, n2)] = edge
for n1, edge, n2 in extracted_edges:
if create_edge_identifier(n1, edge, n2) in edge_map:
continue
edge_map[create_edge_identifier(n1, edge, n2)] = edge
# Prepare context for LLM
context = {
'extracted_edges': [
{'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
for n1, edge, n2 in extracted_edges
],
'existing_edges': [
{'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
for n1, edge, n2 in extracted_edges
],
}
logger.info(prompt_library.dedupe_edges.v2(context))
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context))
new_edges_data = llm_response.get('new_edges', [])
logger.info(f'Extracted new edges: {new_edges_data}')
# Get full edge data
edges = []
for edge_data in new_edges_data:
edge = edge_map[edge_data['triplet']]
edges.append(edge)
return edges
async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],

View File

@@ -52,9 +52,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
}}
""",
]
index_queries: list[LiteralString] = (
range_indices + fulltext_indices + vector_indices
)
index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
@@ -72,7 +70,6 @@ async def retrieve_episodes(
driver: AsyncDriver,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
sources: list[str] | None = 'messages',
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
result = await driver.execute_query(
@@ -97,14 +94,7 @@
created_at=datetime.fromtimestamp(
record['created_at'].to_native().timestamp(), timezone.utc
),
valid_at=(
datetime.fromtimestamp(
record['valid_at'].to_native().timestamp(),
timezone.utc,
)
if record['valid_at'] is not None
else None
),
valid_at=(record['valid_at'].to_native()),
uuid=record['uuid'],
source=record['source'],
name=record['name'],

View File

@@ -9,53 +9,6 @@ from core.prompts import prompt_library
logger = logging.getLogger(__name__)
async def extract_new_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
# Prepare context for LLM
existing_nodes = [
{'name': node_name, 'label': node_info['label'], 'uuid': node_info['uuid']}
for node_name, node_info in relevant_schema['nodes'].items()
]
context = {
'episode_content': episode.content,
'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
'existing_nodes': existing_nodes,
'previous_episodes': [
{
'content': ep.content,
'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v1(context))
new_nodes_data = llm_response.get('new_nodes', [])
logger.info(f'Extracted new nodes: {new_nodes_data}')
# Convert the extracted data into EntityNode objects
new_nodes = []
for node_data in new_nodes_data:
# Check if the node already exists
if not any(existing_node['name'] == node_data['name'] for existing_node in existing_nodes):
new_node = EntityNode(
name=node_data['name'],
labels=node_data['labels'],
summary=node_data['summary'],
created_at=datetime.now(),
)
new_nodes.append(new_node)
logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
else:
logger.info(f"Node {node_data['name']} already exists, skipping creation.")
return new_nodes
async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
@@ -100,16 +53,16 @@ async def dedupe_extracted_nodes(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
existing_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
start = time()
# build existing node map
node_map = {}
node_map: dict[str, EntityNode] = {}
for node in existing_nodes:
node_map[node.name] = node
# Temp hack
new_nodes_map = {}
new_nodes_map: dict[str, EntityNode] = {}
for node in extracted_nodes:
new_nodes_map[node.name] = node
@@ -134,14 +87,14 @@ async def dedupe_extracted_nodes(
end = time()
logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
uuid_map = {}
uuid_map: dict[str, str] = {}
for duplicate in duplicate_data:
uuid = new_nodes_map[duplicate['name']].uuid
uuid_value = node_map[duplicate['duplicate_of']].uuid
uuid_map[uuid] = uuid_value
nodes = []
brand_new_nodes = []
nodes: list[EntityNode] = []
brand_new_nodes: list[EntityNode] = []
for node in extracted_nodes:
if node.uuid in uuid_map:
existing_uuid = uuid_map[node.uuid]
@@ -149,7 +102,9 @@ async def dedupe_extracted_nodes(
# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
nodes.append(existing_node)
if existing_node:
nodes.append(existing_node)
continue
brand_new_nodes.append(node)
nodes.append(node)
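
Besides the typed maps, dedupe_extracted_nodes now declares the three-tuple it was already returning, and the next(...) lookup is guarded so nodes.append never receives None. A caller sketch (tuple order inferred from the annotation; inputs assumed to exist):

from core.utils.maintenance.node_operations import dedupe_extracted_nodes

async def dedupe(llm_client, extracted_nodes, existing_nodes):
    nodes, uuid_map, brand_new_nodes = await dedupe_extracted_nodes(
        llm_client, extracted_nodes, existing_nodes
    )
    return nodes, uuid_map, brand_new_nodes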

View File

@@ -23,6 +23,8 @@ def extract_node_edge_node_triplet(
) -> NodeEdgeNodeTriplet:
source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
if not source_node or not target_node:
raise ValueError(f'Source or target node not found for edge {edge.uuid}')
return (source_node, edge, target_node)
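
Raising on a missing endpoint keeps the declared return type a plain NodeEdgeNodeTriplet rather than an Optional every caller would have to unwrap. A caller sketch (parameter order assumed from the hunk):

from core.utils.maintenance.temporal_operations import extract_node_edge_node_triplet

def to_triplets(edges, nodes):
    # safe to build unconditionally: a missing node raises ValueError instead of yielding None
    return [extract_node_edge_node_triplet(edge, nodes) for edge in edges]
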
@@ -31,11 +33,8 @@ def prepare_edges_for_invalidation(
new_edges: list[EntityEdge],
nodes: list[EntityNode],
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
existing_edges_pending_invalidation = [] # TODO: this is not yet used?
new_edges_with_nodes = [] # TODO: this is not yet used?
existing_edges_pending_invalidation = []
new_edges_with_nodes = []
existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []
for edge_list, result_list in [
(existing_edges, existing_edges_pending_invalidation),

View File

@@ -1,292 +0,0 @@
import asyncio
import logging
from datetime import datetime
from time import time
from neo4j import AsyncDriver
from neo4j import time as neo4j_time
from core.edges import EntityEdge
from core.nodes import EntityNode
logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
async def bfs(node_ids: list[str], driver: AsyncDriver):
records, _, _ = await driver.execute_query(
"""
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
m.uuid AS target_node_uuid,
m.name AS target_name,
m.summary AS target_summary,
r.uuid AS uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
""",
node_ids=node_ids,
)
context = {}
for record in records:
n_uuid = record['source_node_uuid']
if n_uuid in context:
context[n_uuid]['facts'].append(record['fact'])
else:
context[n_uuid] = {
'name': record['source_name'],
'summary': record['source_summary'],
'facts': [record['fact']],
}
m_uuid = record['target_node_uuid']
if m_uuid not in context:
context[m_uuid] = {
'name': record['target_name'],
'summary': record['target_summary'],
'facts': [],
}
logger.info(f'bfs search returned context: {context}')
return context
async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
search_vector=search_vector,
limit=limit,
)
edges: list[EntityEdge] = []
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=safely_parse_db_date(record['created_at']),
expired_at=safely_parse_db_date(record['expired_at']),
valid_at=safely_parse_db_date(record['valid_at']),
invalid_At=safely_parse_db_date(record['invalid_at']),
)
edges.append(edge)
return edges
async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
RETURN
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
ORDER BY score DESC
""",
search_vector=search_vector,
limit=limit,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
labels=[],
created_at=safely_parse_db_date(record['created_at']),
summary=record['summary'],
)
)
return nodes
async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = query + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN
node.uuid As uuid,
node.name AS name,
node.created_at AS created_at,
node.summary AS summary
ORDER BY score DESC
LIMIT $limit
""",
query=fuzzy_query,
limit=limit,
)
nodes: list[EntityNode] = []
for record in records:
nodes.append(
EntityNode(
uuid=record['uuid'],
name=record['name'],
labels=[],
created_at=safely_parse_db_date(record['created_at']),
summary=record['summary'],
)
)
return nodes
async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = query + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
r.created_at AS created_at,
r.name AS name,
r.fact AS fact,
r.fact_embedding AS fact_embedding,
r.episodes AS episodes,
r.expired_at AS expired_at,
r.valid_at AS valid_at,
r.invalid_at AS invalid_at
ORDER BY score DESC LIMIT $limit
""",
query=fuzzy_query,
limit=limit,
)
edges: list[EntityEdge] = []
for record in records:
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
name=record['name'],
episodes=record['episodes'],
fact_embedding=record['fact_embedding'],
created_at=safely_parse_db_date(record['created_at']),
expired_at=safely_parse_db_date(record['expired_at']),
valid_at=safely_parse_db_date(record['valid_at']),
invalid_At=safely_parse_db_date(record['invalid_at']),
)
edges.append(edge)
return edges
def safely_parse_db_date(date_str: neo4j_time.Date) -> datetime:
if date_str:
return datetime.fromisoformat(date_str.iso_format())
return None
async def get_relevant_nodes(
nodes: list[EntityNode],
driver: AsyncDriver,
) -> list[EntityNode]:
start = time()
relevant_nodes: list[EntityNode] = []
relevant_node_uuids = set()
results = await asyncio.gather(
*[entity_fulltext_search(node.name, driver) for node in nodes],
*[entity_similarity_search(node.name_embedding, driver) for node in nodes],
)
for result in results:
for node in result:
if node.uuid in relevant_node_uuids:
continue
relevant_node_uuids.add(node.uuid)
relevant_nodes.append(node)
end = time()
logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
return relevant_nodes
async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver,
) -> list[EntityEdge]:
start = time()
relevant_edges: list[EntityEdge] = []
relevant_edge_uuids = set()
results = await asyncio.gather(
*[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
*[edge_fulltext_search(edge.fact, driver) for edge in edges],
)
for result in results:
for edge in result:
if edge.uuid in relevant_edge_uuids:
continue
relevant_edge_uuids.add(edge.uuid)
relevant_edges.append(edge)
end = time()
logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
return relevant_edges

View File

@@ -14,8 +14,8 @@ def build_episodic_edges(
for node in entity_nodes:
edges.append(
EpisodicEdge(
source_node_uuid=episode,
target_node_uuid=node,
source_node_uuid=episode.uuid,
target_node_uuid=node.uuid,
created_at=episode.created_at,
)
)

View File

@@ -2,7 +2,10 @@
name = "graphiti"
version = "0.0.1"
description = "Graph building library"
authors = ["Paul Paliychuk <paul@getzep.com>", "Preston Rasmussen <preston@getzep.com>"]
authors = [
"Paul Paliychuk <paul@getzep.com>",
"Preston Rasmussen <preston@getzep.com>",
]
readme = "README.md"
[tool.poetry.dependencies]
@@ -56,4 +59,4 @@ ignore = ["E501"]
[tool.ruff.format]
quote-style = "single"
indent-style = "tab"
docstring-code-format = true
docstring-code-format = true

View File

@@ -103,11 +103,11 @@ async def test_graph_integration():
bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
episodic_edge_1 = EpisodicEdge(
source_node_uuid=episode, target_node_uuid=alice_node, created_at=now
source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
)
episodic_edge_2 = EpisodicEdge(
source_node_uuid=episode, target_node_uuid=bob_node, created_at=now
source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
)
entity_edge = EntityEdge(