Mirror of https://github.com/getzep/graphiti.git (synced 2025-06-27 02:00:02 +00:00)
chore: Fix Typing Issues (#27)
* typing.Any and friends
* message
* chore: Import Message model in llm_client
* fix: 💄 mypy errors
* clean up mypy stuff
* mypy
* format
* mypy
* mypy
* mypy
---------
Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
Co-authored-by: prestonrasmussen <prasmuss15@gmail.com>
parent: 7152a211ae
commit: 9cc9883e66

Makefile (2 changed lines)
@@ -22,7 +22,7 @@ format:
 # Lint code
 lint:
 	$(RUFF) check
-	$(MYPY) . --show-column-numbers --show-error-codes --pretty
+	$(MYPY) ./core --show-column-numbers --show-error-codes --pretty

 # Run tests
 test:

@@ -56,7 +56,7 @@ class Graphiti:
         else:
             self.llm_client = OpenAIClient(
                 LLMConfig(
-                    api_key=os.getenv('OPENAI_API_KEY'),
+                    api_key=os.getenv('OPENAI_API_KEY', default=''),
                     model='gpt-4o-mini',
                     base_url='https://api.openai.com/v1',
                 )

@@ -72,28 +72,16 @@ class Graphiti:
         self,
         reference_time: datetime,
         last_n: int = EPISODE_WINDOW_LEN,
-        sources: list[str] | None = 'messages',
     ) -> list[EpisodicNode]:
         """Retrieve the last n episodic nodes from the graph"""
-        return await retrieve_episodes(self.driver, reference_time, last_n, sources)
-
-    # Invalidate edges that are no longer valid
-    async def invalidate_edges(
-        self,
-        episode: EpisodicNode,
-        new_nodes: list[EntityNode],
-        new_edges: list[EntityEdge],
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ): ...
+        return await retrieve_episodes(self.driver, reference_time, last_n)

     async def add_episode(
         self,
         name: str,
         episode_body: str,
         source_description: str,
-        reference_time: datetime | None = None,
         episode_type: str | None = 'string',  # TODO: this field isn't used yet?
+        reference_time: datetime,
         success_callback: Callable | None = None,
         error_callback: Callable | None = None,
     ):

@@ -104,7 +92,7 @@ class Graphiti:
         nodes: list[EntityNode] = []
         entity_edges: list[EntityEdge] = []
         episodic_edges: list[EpisodicEdge] = []
-        embedder = self.llm_client.client.embeddings
+        embedder = self.llm_client.get_embedder()
         now = datetime.now()

         previous_episodes = await self.retrieve_episodes(reference_time)

@@ -234,7 +222,7 @@ class Graphiti:
         ):
             try:
                 start = time()
-                embedder = self.llm_client.client.embeddings
+                embedder = self.llm_client.get_embedder()
                 now = datetime.now()

                 episodes = [

@@ -276,14 +264,22 @@ class Graphiti:
         await asyncio.gather(*[node.save(self.driver) for node in nodes])

         # re-map edge pointers so that they don't point to discard dupe nodes
-        extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map)
-        episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map)
+        extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
+            extracted_edges, uuid_map
+        )
+        episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
+            episodic_edges, uuid_map
+        )

         # save episodic edges to KG
-        await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+        await asyncio.gather(
+            *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
+        )

         # Dedupe extracted edges
-        edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges)
+        edges = await dedupe_edges_bulk(
+            self.driver, self.llm_client, extracted_edges_with_resolved_pointers
+        )
         logger.info(f'extracted edge length: {len(edges)}')

         # invalidate edges

@@ -302,12 +298,12 @@ class Graphiti:
         edges = (
             await hybrid_search(
                 self.driver,
-                self.llm_client.client.embeddings,
+                self.llm_client.get_embedder(),
                 query,
                 datetime.now(),
                 search_config,
             )
-        )['edges']
+        ).edges

         facts = [edge.fact for edge in edges]

@@ -315,5 +311,5 @@ class Graphiti:

     async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
         return await hybrid_search(
-            self.driver, self.llm_client.client.embeddings, query, timestamp, config
+            self.driver, self.llm_client.get_embedder(), query, timestamp, config
         )
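Note on the embedder change above: call sites that reached into self.llm_client.client.embeddings (an OpenAI-specific attribute) now go through get_embedder(), declared on the abstract LLMClient, so this file type-checks against the interface instead of one concrete client. A minimal sketch of the pattern, with illustrative names rather than the library's real classes:

import typing
from abc import ABC, abstractmethod

class LLMClientBase(ABC):  # stand-in for the abstract LLMClient
    @abstractmethod
    def get_embedder(self) -> typing.Any:
        """Return the provider-specific embedding handle."""

class StubOpenAIClient(LLMClientBase):
    def __init__(self) -> None:
        self._embeddings = object()  # stands in for AsyncOpenAI().embeddings

    def get_embedder(self) -> typing.Any:
        # Callers depend on the interface, not on the client's attribute layout.
        return self._embeddings

embedder = StubOpenAIClient().get_embedder()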
@@ -1,5 +1,7 @@
+import typing
 from abc import ABC, abstractmethod

+from ..prompts.models import Message
 from .config import LLMConfig

@@ -9,5 +11,9 @@ class LLMClient(ABC):
         pass

     @abstractmethod
-    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
+    def get_embedder(self) -> typing.Any:
+        pass
+
+    @abstractmethod
+    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
         pass
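The recurring one-word fix in this commit: any is the builtin function, not a type, so annotations like dict[str, any] fail under mypy (it reports that builtins.any is not valid as a type); typing.Any is the intended escape hatch. Before/after sketch:

from typing import Any

# def v1(context: dict[str, any]) -> None: ...   # mypy error: any is a function
def v1(context: dict[str, Any]) -> None:
    # Any opts the dict's values out of type checking, which fits the
    # free-form prompt context these functions receive.
    pass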
@@ -1,8 +1,11 @@
 import json
 import logging
+import typing

 from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam

+from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig

@@ -14,16 +17,26 @@ class OpenAIClient(LLMClient):
         self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         self.model = config.model

-    async def generate_response(self, messages: list[dict[str, str]]) -> dict[str, any]:
+    def get_embedder(self) -> typing.Any:
+        return self.client.embeddings
+
+    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
+        openai_messages: list[ChatCompletionMessageParam] = []
+        for m in messages:
+            if m.role == 'user':
+                openai_messages.append({'role': 'user', 'content': m.content})
+            elif m.role == 'system':
+                openai_messages.append({'role': 'system', 'content': m.content})
         try:
             response = await self.client.chat.completions.create(
                 model=self.model,
-                messages=messages,
+                messages=openai_messages,
                 temperature=0.1,
                 max_tokens=3000,
                 response_format={'type': 'json_object'},
             )
-            return json.loads(response.choices[0].message.content)
+            result = response.choices[0].message.content or ''
+            return json.loads(result)
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
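The result = response.choices[0].message.content or '' line exists because the OpenAI SDK types content as str | None; passing it straight to json.loads is a mypy error, and the or '' narrows it to str. An equivalent, more explicit sketch of the same narrowing:

import json

def parse_llm_json(content: str | None) -> dict:
    # Raise a clear error instead of handing None (or '') to json.loads.
    if not content:
        raise ValueError('LLM returned empty message content')
    return json.loads(content)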
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -7,6 +7,7 @@ from .models import Message, PromptFunction, PromptVersion
 class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
+    edge_list: PromptVersion


 class Versions(TypedDict):

@@ -15,7 +16,7 @@ class Versions(TypedDict):
     edge_list: PromptFunction


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',

@@ -55,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',

@@ -97,7 +98,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]


-def edge_list(context: dict[str, any]) -> list[Message]:
+def edge_list(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -16,7 +16,7 @@ class Versions(TypedDict):
     node_list: PromptVersion


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',

@@ -56,7 +56,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',

@@ -96,7 +96,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]


-def node_list(context: dict[str, any]) -> list[Message]:
+def node_list(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -14,7 +14,7 @@ class Versions(TypedDict):
     v2: PromptFunction


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',

@@ -70,7 +70,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -1,5 +1,5 @@
 import json
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -16,7 +16,7 @@ class Versions(TypedDict):
     v3: PromptFunction


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',

@@ -64,7 +64,7 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-def v2(context: dict[str, any]) -> list[Message]:
+def v2(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',

@@ -105,7 +105,7 @@ def v2(context: dict[str, any]) -> list[Message]:
     ]


-def v3(context: dict[str, any]) -> list[Message]:
+def v3(context: dict[str, Any]) -> list[Message]:
     sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

     user_prompt = f"""
@@ -1,4 +1,4 @@
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .models import Message, PromptFunction, PromptVersion

@@ -11,7 +11,7 @@ class Versions(TypedDict):
     v1: PromptFunction


-def v1(context: dict[str, any]) -> list[Message]:
+def v1(context: dict[str, Any]) -> list[Message]:
     return [
         Message(
             role='system',
@@ -1,4 +1,4 @@
-from typing import Protocol, TypedDict
+from typing import Any, Protocol, TypedDict

 from .dedupe_edges import (
     Prompt as DedupeEdgesPrompt,

@@ -68,7 +68,7 @@ class VersionWrapper:
     def __init__(self, func: PromptFunction):
         self.func = func

-    def __call__(self, context: dict[str, any]) -> list[Message]:
+    def __call__(self, context: dict[str, Any]) -> list[Message]:
         return self.func(context)


@@ -81,7 +81,7 @@ class PromptTypeWrapper:
 class PromptLibraryWrapper:
     def __init__(self, library: PromptLibraryImpl):
         for prompt_type, versions in library.items():
-            setattr(self, prompt_type, PromptTypeWrapper(versions))
+            setattr(self, prompt_type, PromptTypeWrapper(versions))  # type: ignore[arg-type]


 PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {

@@ -91,5 +91,4 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
     'dedupe_edges': dedupe_edges_versions,
     'invalidate_edges': invalidate_edges_versions,
 }

-prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
+prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)  # type: ignore[assignment]
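The dynamic wrapper layer gets targeted ignores, # type: ignore[arg-type] and # type: ignore[assignment], rather than blanket ones; scoping the ignore to a single error code leaves mypy checking everything else on the line. Sketch of the difference:

x: int = 'a'  # type: ignore              # silences every error on this line
y: int = 'b'  # type: ignore[assignment]  # silences only this error code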
@@ -1,4 +1,4 @@
-from typing import Callable, Protocol
+from typing import Any, Callable, Protocol

 from pydantic import BaseModel

@@ -9,7 +9,7 @@ class Message(BaseModel):


 class PromptVersion(Protocol):
-    def __call__(self, context: dict[str, any]) -> list[Message]: ...
+    def __call__(self, context: dict[str, Any]) -> list[Message]: ...


-PromptFunction = Callable[[dict[str, any]], list[Message]]
+PromptFunction = Callable[[dict[str, Any]], list[Message]]
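PromptVersion is a callable Protocol: any function whose signature matches __call__ satisfies it structurally, with no inheritance needed, which is how the plain v1/v2 functions above pass as prompt versions. A self-contained sketch (using list[str] in place of list[Message]):

from typing import Any, Protocol

class PromptVersion(Protocol):
    def __call__(self, context: dict[str, Any]) -> list[str]: ...

def v1(context: dict[str, Any]) -> list[str]:
    return [f"system: {context.get('episode_content', '')}"]

# v1 never names PromptVersion, yet type-checks as one (structural typing).
version: PromptVersion = v1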
@@ -5,9 +5,9 @@ from time import time
 from neo4j import AsyncDriver
 from pydantic import BaseModel

-from core.edges import Edge
+from core.edges import EntityEdge
 from core.llm_client.config import EMBEDDING_DIM
-from core.nodes import Node
+from core.nodes import EntityNode, EpisodicNode
 from core.search.search_utils import (
     edge_fulltext_search,
     edge_similarity_search,

@@ -28,9 +28,15 @@ class SearchConfig(BaseModel):
     reranker: str = 'rrf'


+class SearchResults(BaseModel):
+    episodes: list[EpisodicNode]
+    nodes: list[EntityNode]
+    edges: list[EntityEdge]
+
+
 async def hybrid_search(
     driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
-) -> dict[str, [Node | Edge]]:
+) -> SearchResults:
     start = time()

     episodes = []

@@ -86,11 +92,7 @@ async def hybrid_search(
     reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
     edges.extend(reranked_edges)

-    context = {
-        'episodes': episodes,
-        'nodes': nodes,
-        'edges': edges,
-    }
+    context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)

     end = time()
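Returning a Pydantic model instead of dict[str, [Node | Edge]] (a list literal inside a subscript, which mypy rejects as a type) is why the call site in the Graphiti class changed from )['edges'] to ).edges: field access is now statically checked. A sketch of the payoff, with stand-in element types instead of the real node/edge classes:

from pydantic import BaseModel

class SearchResults(BaseModel):  # mirrors the shape added in this commit
    episodes: list[str]  # the real model holds episode/node/edge objects
    nodes: list[str]
    edges: list[str]

results = SearchResults(episodes=[], nodes=['alice'], edges=['alice-knows-bob'])
print(results.edges)  # a typo like results.edegs is a mypy error, not a runtime KeyError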
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import typing
 from collections import defaultdict
 from datetime import datetime
 from time import time

@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
 RELEVANT_SCHEMA_LIMIT = 3


-def parse_db_date(neo_date: neo4j_time.Date | None) -> datetime | None:
+def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
     return neo_date.to_native() if neo_date else None

@@ -41,7 +42,7 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
             uuid=record['uuid'],
             name=record['name'],
             labels=['Entity'],
-            created_at=datetime.now(),
+            created_at=record['created_at'].to_native(),
             summary=record['summary'],
         )
     )

@@ -74,7 +75,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
         node_ids=node_ids,
     )

-    context = {}
+    context: dict[str, typing.Any] = {}

     for record in records:
         n_uuid = record['source_node_uuid']

@@ -173,7 +174,7 @@ async def entity_similarity_search(
             uuid=record['uuid'],
             name=record['name'],
             labels=['Entity'],
-            created_at=datetime.now(),
+            created_at=record['created_at'].to_native(),
             summary=record['summary'],
         )
     )

@@ -208,7 +209,7 @@ async def entity_fulltext_search(
             uuid=record['uuid'],
             name=record['name'],
             labels=['Entity'],
-            created_at=datetime.now(),
+            created_at=record['created_at'].to_native(),
             summary=record['summary'],
         )
     )

@@ -277,7 +278,11 @@ async def get_relevant_nodes(

     results = await asyncio.gather(
         *[entity_fulltext_search(node.name, driver) for node in nodes],
-        *[entity_similarity_search(node.name_embedding, driver) for node in nodes],
+        *[
+            entity_similarity_search(node.name_embedding, driver)
+            for node in nodes
+            if node.name_embedding is not None
+        ],
     )

     for result in results:

@@ -303,7 +308,11 @@ async def get_relevant_edges(
     relevant_edge_uuids = set()

     results = await asyncio.gather(
-        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
+        *[
+            edge_similarity_search(edge.fact_embedding, driver)
+            for edge in edges
+            if edge.fact_embedding is not None
+        ],
         *[edge_fulltext_search(edge.fact, driver) for edge in edges],
     )
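The added is-not-None guards in the gather comprehensions do double duty: they skip nodes and edges that were never embedded, and they narrow list[float] | None to list[float] so the search helpers' signatures are satisfied. The pattern in isolation:

def similarity_search(vector: list[float]) -> list[str]:
    return []  # stand-in for the real vector query

embeddings: list[list[float] | None] = [[0.1, 0.2], None, [0.3, 0.4]]

# The filter narrows each element to list[float] for mypy.
results = [similarity_search(v) for v in embeddings if v is not None]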
@@ -1,15 +1,15 @@
 from .maintenance import (
     build_episodic_edges,
     clear_data,
-    extract_new_edges,
-    extract_new_nodes,
+    extract_edges,
+    extract_nodes,
     retrieve_episodes,
 )

 __all__ = [
-    'extract_new_edges',
+    'extract_edges',
     'build_episodic_edges',
-    'extract_new_nodes',
+    'extract_nodes',
     'clear_data',
     'retrieve_episodes',
 ]
@@ -1,4 +1,5 @@
 import asyncio
+import typing
 from datetime import datetime

 from neo4j import AsyncDriver

@@ -121,8 +122,8 @@ async def dedupe_edges_bulk(


 def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
-    uuid_map = {}
-    name_map = {}
+    uuid_map: dict[str, str] = {}
+    name_map: dict[str, EntityNode] = {}
     for node in nodes:
         if node.name in name_map:
             uuid_map[node.uuid] = name_map[node.name].uuid

@@ -182,7 +183,10 @@ def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
     return compressed_map


-def resolve_edge_pointers(edges: list[Edge], uuid_map: dict[str, str]):
+E = typing.TypeVar('E', bound=Edge)
+
+
+def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
     for edge in edges:
         source_uuid = edge.source_node_uuid
         target_uuid = edge.target_node_uuid
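The new E = typing.TypeVar('E', bound=Edge) lets resolve_edge_pointers accept any Edge subclass and hand back the same concrete type, instead of widening a list[EntityEdge] to list[Edge] as the old signature did. The mechanism in miniature:

import typing

class Edge:
    uuid: str = ''

class EntityEdge(Edge):
    fact: str = ''

E = typing.TypeVar('E', bound=Edge)

def first_edge(edges: list[E]) -> E:
    # The bound guarantees Edge attributes; the TypeVar preserves the subclass.
    return edges[0]

e = first_edge([EntityEdge()])  # mypy infers EntityEdge, so e.fact type-checks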
@@ -1,15 +1,15 @@
-from .edge_operations import build_episodic_edges, extract_new_edges
+from .edge_operations import build_episodic_edges, extract_edges
 from .graph_data_operations import (
     clear_data,
     retrieve_episodes,
 )
-from .node_operations import extract_new_nodes
+from .node_operations import extract_nodes
 from .temporal_operations import invalidate_edges

 __all__ = [
-    'extract_new_edges',
+    'extract_edges',
     'build_episodic_edges',
-    'extract_new_nodes',
+    'extract_nodes',
     'clear_data',
     'retrieve_episodes',
     'invalidate_edges',
@@ -1,4 +1,3 @@
-import json
 import logging
 from datetime import datetime
 from time import time

@@ -8,7 +7,6 @@ from core.edges import EntityEdge, EpisodicEdge
 from core.llm_client import LLMClient
 from core.nodes import EntityNode, EpisodicNode
 from core.prompts import prompt_library
-from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet

 logger = logging.getLogger(__name__)

@@ -31,103 +29,6 @@ def build_episodic_edges(
     return edges


-async def extract_new_edges(
-    llm_client: LLMClient,
-    episode: EpisodicNode,
-    new_nodes: list[EntityNode],
-    relevant_schema: dict[str, any],
-    previous_episodes: list[EpisodicNode],
-) -> tuple[list[EntityEdge], list[EntityNode]]:
-    # Prepare context for LLM
-    context = {
-        'episode_content': episode.content,
-        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
-        'relevant_schema': json.dumps(relevant_schema, indent=2),
-        'new_nodes': [{'name': node.name, 'summary': node.summary} for node in new_nodes],
-        'previous_episodes': [
-            {
-                'content': ep.content,
-                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
-            }
-            for ep in previous_episodes
-        ],
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.extract_edges.v1(context))
-    new_edges_data = llm_response.get('new_edges', [])
-    logger.info(f'Extracted new edges: {new_edges_data}')
-
-    # Convert the extracted data into EntityEdge objects
-    new_edges = []
-    for edge_data in new_edges_data:
-        source_node = next(
-            (node for node in new_nodes if node.name == edge_data['source_node']),
-            None,
-        )
-        target_node = next(
-            (node for node in new_nodes if node.name == edge_data['target_node']),
-            None,
-        )
-
-        # If source or target is not in new_nodes, check if it's an existing node
-        if source_node is None and edge_data['source_node'] in relevant_schema['nodes']:
-            existing_node_data = relevant_schema['nodes'][edge_data['source_node']]
-            source_node = EntityNode(
-                uuid=existing_node_data['uuid'],
-                name=edge_data['source_node'],
-                labels=[existing_node_data['label']],
-                summary='',
-                created_at=datetime.now(),
-            )
-        if target_node is None and edge_data['target_node'] in relevant_schema['nodes']:
-            existing_node_data = relevant_schema['nodes'][edge_data['target_node']]
-            target_node = EntityNode(
-                uuid=existing_node_data['uuid'],
-                name=edge_data['target_node'],
-                labels=[existing_node_data['label']],
-                summary='',
-                created_at=datetime.now(),
-            )
-
-        if (
-            source_node
-            and target_node
-            and not (
-                source_node.name.startswith('Message') or target_node.name.startswith('Message')
-            )
-        ):
-            valid_at = (
-                datetime.fromisoformat(edge_data['valid_at'])
-                if edge_data['valid_at']
-                else episode.valid_at or datetime.now()
-            )
-            invalid_at = (
-                datetime.fromisoformat(edge_data['invalid_at']) if edge_data['invalid_at'] else None
-            )
-
-            new_edge = EntityEdge(
-                source_node=source_node,
-                target_node=target_node,
-                name=edge_data['relation_type'],
-                fact=edge_data['fact'],
-                episodes=[episode.uuid],
-                created_at=datetime.now(),
-                valid_at=valid_at,
-                invalid_at=invalid_at,
-            )
-            new_edges.append(new_edge)
-            logger.info(
-                f'Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})'
-            )
-
-    affected_nodes = set()
-
-    for edge in new_edges:
-        affected_nodes.add(edge.source_node)
-        affected_nodes.add(edge.target_node)
-    return new_edges, list(affected_nodes)
-
-
 async def extract_edges(
     llm_client: LLMClient,
     episode: EpisodicNode,

@@ -186,45 +87,6 @@ def create_edge_identifier(
     return f'{source_node.name}-{edge.name}-{target_node.name}'


-async def dedupe_extracted_edges_v2(
-    llm_client: LLMClient,
-    extracted_edges: list[NodeEdgeNodeTriplet],
-    existing_edges: list[NodeEdgeNodeTriplet],
-) -> list[NodeEdgeNodeTriplet]:
-    # Create edge map
-    edge_map = {}
-    for n1, edge, n2 in existing_edges:
-        edge_map[create_edge_identifier(n1, edge, n2)] = edge
-    for n1, edge, n2 in extracted_edges:
-        if create_edge_identifier(n1, edge, n2) in edge_map:
-            continue
-        edge_map[create_edge_identifier(n1, edge, n2)] = edge
-
-    # Prepare context for LLM
-    context = {
-        'extracted_edges': [
-            {'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
-            for n1, edge, n2 in extracted_edges
-        ],
-        'existing_edges': [
-            {'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
-            for n1, edge, n2 in extracted_edges
-        ],
-    }
-    logger.info(prompt_library.dedupe_edges.v2(context))
-    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context))
-    new_edges_data = llm_response.get('new_edges', [])
-    logger.info(f'Extracted new edges: {new_edges_data}')
-
-    # Get full edge data
-    edges = []
-    for edge_data in new_edges_data:
-        edge = edge_map[edge_data['triplet']]
-        edges.append(edge)
-
-    return edges
-
-
 async def dedupe_extracted_edges(
     llm_client: LLMClient,
     extracted_edges: list[EntityEdge],
@@ -52,9 +52,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
         }}
         """,
     ]
-    index_queries: list[LiteralString] = (
-        range_indices + fulltext_indices + vector_indices
-    )
+    index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices

     await asyncio.gather(*[driver.execute_query(query) for query in index_queries])

@@ -72,7 +70,6 @@ async def retrieve_episodes(
     driver: AsyncDriver,
     reference_time: datetime,
     last_n: int = EPISODE_WINDOW_LEN,
-    sources: list[str] | None = 'messages',
 ) -> list[EpisodicNode]:
     """Retrieve the last n episodic nodes from the graph"""
     result = await driver.execute_query(

@@ -97,14 +94,7 @@
             created_at=datetime.fromtimestamp(
                 record['created_at'].to_native().timestamp(), timezone.utc
             ),
-            valid_at=(
-                datetime.fromtimestamp(
-                    record['valid_at'].to_native().timestamp(),
-                    timezone.utc,
-                )
-                if record['valid_at'] is not None
-                else None
-            ),
+            valid_at=(record['valid_at'].to_native()),
             uuid=record['uuid'],
             source=record['source'],
             name=record['name'],
@@ -9,53 +9,6 @@ from core.prompts import prompt_library
 logger = logging.getLogger(__name__)


-async def extract_new_nodes(
-    llm_client: LLMClient,
-    episode: EpisodicNode,
-    relevant_schema: dict[str, any],
-    previous_episodes: list[EpisodicNode],
-) -> list[EntityNode]:
-    # Prepare context for LLM
-    existing_nodes = [
-        {'name': node_name, 'label': node_info['label'], 'uuid': node_info['uuid']}
-        for node_name, node_info in relevant_schema['nodes'].items()
-    ]
-
-    context = {
-        'episode_content': episode.content,
-        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
-        'existing_nodes': existing_nodes,
-        'previous_episodes': [
-            {
-                'content': ep.content,
-                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
-            }
-            for ep in previous_episodes
-        ],
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v1(context))
-    new_nodes_data = llm_response.get('new_nodes', [])
-    logger.info(f'Extracted new nodes: {new_nodes_data}')
-    # Convert the extracted data into EntityNode objects
-    new_nodes = []
-    for node_data in new_nodes_data:
-        # Check if the node already exists
-        if not any(existing_node['name'] == node_data['name'] for existing_node in existing_nodes):
-            new_node = EntityNode(
-                name=node_data['name'],
-                labels=node_data['labels'],
-                summary=node_data['summary'],
-                created_at=datetime.now(),
-            )
-            new_nodes.append(new_node)
-            logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
-        else:
-            logger.info(f"Node {node_data['name']} already exists, skipping creation.")
-
-    return new_nodes
-
-
 async def extract_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,

@@ -100,16 +53,16 @@ async def dedupe_extracted_nodes(
     llm_client: LLMClient,
     extracted_nodes: list[EntityNode],
     existing_nodes: list[EntityNode],
-) -> tuple[list[EntityNode], dict[str, str]]:
+) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]:
     start = time()

     # build existing node map
-    node_map = {}
+    node_map: dict[str, EntityNode] = {}
     for node in existing_nodes:
         node_map[node.name] = node

     # Temp hack
-    new_nodes_map = {}
+    new_nodes_map: dict[str, EntityNode] = {}
     for node in extracted_nodes:
         new_nodes_map[node.name] = node

@@ -134,14 +87,14 @@ async def dedupe_extracted_nodes(
     end = time()
     logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')

-    uuid_map = {}
+    uuid_map: dict[str, str] = {}
     for duplicate in duplicate_data:
         uuid = new_nodes_map[duplicate['name']].uuid
         uuid_value = node_map[duplicate['duplicate_of']].uuid
         uuid_map[uuid] = uuid_value

-    nodes = []
-    brand_new_nodes = []
+    nodes: list[EntityNode] = []
+    brand_new_nodes: list[EntityNode] = []
     for node in extracted_nodes:
         if node.uuid in uuid_map:
             existing_uuid = uuid_map[node.uuid]

@@ -149,7 +102,9 @@ async def dedupe_extracted_nodes(
             # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
             # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
             existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
-            nodes.append(existing_node)
+            if existing_node:
+                nodes.append(existing_node)

             continue
         brand_new_nodes.append(node)
         nodes.append(node)
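next(iterator, None) returns EntityNode | None, so the old unconditional nodes.append(existing_node) could smuggle None into a list[EntityNode]; the added if existing_node: both narrows the type and skips the append when no match exists. The same pattern in isolation:

names = ['alice', 'bob']
match: str | None = next((n for n in names if n.startswith('c')), None)

typed: list[str] = []
if match:  # narrows str | None to str
    typed.append(match)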
@@ -23,6 +23,8 @@ def extract_node_edge_node_triplet(
 ) -> NodeEdgeNodeTriplet:
     source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
     target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
+    if not source_node or not target_node:
+        raise ValueError(f'Source or target node not found for edge {edge.uuid}')
     return (source_node, edge, target_node)

@@ -31,11 +33,8 @@ def prepare_edges_for_invalidation(
     new_edges: list[EntityEdge],
     nodes: list[EntityNode],
 ) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
-    existing_edges_pending_invalidation = []  # TODO: this is not yet used?
-    new_edges_with_nodes = []  # TODO: this is not yet used?
-
-    existing_edges_pending_invalidation = []
-    new_edges_with_nodes = []
+    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = []
+    new_edges_with_nodes: list[NodeEdgeNodeTriplet] = []

     for edge_list, result_list in [
         (existing_edges, existing_edges_pending_invalidation),
@@ -1,292 +0,0 @@
-import asyncio
-import logging
-from datetime import datetime
-from time import time
-
-from neo4j import AsyncDriver
-from neo4j import time as neo4j_time
-
-from core.edges import EntityEdge
-from core.nodes import EntityNode
-
-logger = logging.getLogger(__name__)
-
-RELEVANT_SCHEMA_LIMIT = 3
-
-
-async def bfs(node_ids: list[str], driver: AsyncDriver):
-    records, _, _ = await driver.execute_query(
-        """
-        MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
-        RETURN
-            n.uuid AS source_node_uuid,
-            n.name AS source_name,
-            n.summary AS source_summary,
-            m.uuid AS target_node_uuid,
-            m.name AS target_name,
-            m.summary AS target_summary,
-            r.uuid AS uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
-
-        """,
-        node_ids=node_ids,
-    )
-
-    context = {}
-
-    for record in records:
-        n_uuid = record['source_node_uuid']
-        if n_uuid in context:
-            context[n_uuid]['facts'].append(record['fact'])
-        else:
-            context[n_uuid] = {
-                'name': record['source_name'],
-                'summary': record['source_summary'],
-                'facts': [record['fact']],
-            }
-
-        m_uuid = record['target_node_uuid']
-        if m_uuid not in context:
-            context[m_uuid] = {
-                'name': record['target_name'],
-                'summary': record['target_summary'],
-                'facts': [],
-            }
-    logger.info(f'bfs search returned context: {context}')
-    return context
-
-
-async def edge_similarity_search(
-    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
-) -> list[EntityEdge]:
-    # vector similarity search over embedded facts
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
-        YIELD relationship AS r, score
-        MATCH (n)-[r:RELATES_TO]->(m)
-        RETURN
-            r.uuid AS uuid,
-            n.uuid AS source_node_uuid,
-            m.uuid AS target_node_uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT $limit
-        """,
-        search_vector=search_vector,
-        limit=limit,
-    )
-
-    edges: list[EntityEdge] = []
-
-    for record in records:
-        edge = EntityEdge(
-            uuid=record['uuid'],
-            source_node_uuid=record['source_node_uuid'],
-            target_node_uuid=record['target_node_uuid'],
-            fact=record['fact'],
-            name=record['name'],
-            episodes=record['episodes'],
-            fact_embedding=record['fact_embedding'],
-            created_at=safely_parse_db_date(record['created_at']),
-            expired_at=safely_parse_db_date(record['expired_at']),
-            valid_at=safely_parse_db_date(record['valid_at']),
-            invalid_At=safely_parse_db_date(record['invalid_at']),
-        )
-
-        edges.append(edge)
-
-    return edges
-
-
-async def entity_similarity_search(
-    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
-) -> list[EntityNode]:
-    # vector similarity search over entity names
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
-        YIELD node AS n, score
-        RETURN
-            n.uuid As uuid,
-            n.name AS name,
-            n.created_at AS created_at,
-            n.summary AS summary
-        ORDER BY score DESC
-        """,
-        search_vector=search_vector,
-        limit=limit,
-    )
-    nodes: list[EntityNode] = []
-
-    for record in records:
-        nodes.append(
-            EntityNode(
-                uuid=record['uuid'],
-                name=record['name'],
-                labels=[],
-                created_at=safely_parse_db_date(record['created_at']),
-                summary=record['summary'],
-            )
-        )
-
-    return nodes
-
-
-async def entity_fulltext_search(
-    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
-) -> list[EntityNode]:
-    # BM25 search to get top nodes
-    fuzzy_query = query + '~'
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
-        RETURN
-            node.uuid As uuid,
-            node.name AS name,
-            node.created_at AS created_at,
-            node.summary AS summary
-        ORDER BY score DESC
-        LIMIT $limit
-        """,
-        query=fuzzy_query,
-        limit=limit,
-    )
-    nodes: list[EntityNode] = []
-
-    for record in records:
-        nodes.append(
-            EntityNode(
-                uuid=record['uuid'],
-                name=record['name'],
-                labels=[],
-                created_at=safely_parse_db_date(record['created_at']),
-                summary=record['summary'],
-            )
-        )
-
-    return nodes
-
-
-async def edge_fulltext_search(
-    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
-) -> list[EntityEdge]:
-    # fulltext search over facts
-    fuzzy_query = query + '~'
-
-    records, _, _ = await driver.execute_query(
-        """
-        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
-        YIELD relationship AS r, score
-        MATCH (n:Entity)-[r]->(m:Entity)
-        RETURN
-            r.uuid AS uuid,
-            n.uuid AS source_node_uuid,
-            m.uuid AS target_node_uuid,
-            r.created_at AS created_at,
-            r.name AS name,
-            r.fact AS fact,
-            r.fact_embedding AS fact_embedding,
-            r.episodes AS episodes,
-            r.expired_at AS expired_at,
-            r.valid_at AS valid_at,
-            r.invalid_at AS invalid_at
-        ORDER BY score DESC LIMIT $limit
-        """,
-        query=fuzzy_query,
-        limit=limit,
-    )
-
-    edges: list[EntityEdge] = []
-
-    for record in records:
-        edge = EntityEdge(
-            uuid=record['uuid'],
-            source_node_uuid=record['source_node_uuid'],
-            target_node_uuid=record['target_node_uuid'],
-            fact=record['fact'],
-            name=record['name'],
-            episodes=record['episodes'],
-            fact_embedding=record['fact_embedding'],
-            created_at=safely_parse_db_date(record['created_at']),
-            expired_at=safely_parse_db_date(record['expired_at']),
-            valid_at=safely_parse_db_date(record['valid_at']),
-            invalid_At=safely_parse_db_date(record['invalid_at']),
-        )
-
-        edges.append(edge)
-
-    return edges
-
-
-def safely_parse_db_date(date_str: neo4j_time.Date) -> datetime:
-    if date_str:
-        return datetime.fromisoformat(date_str.iso_format())
-    return None
-
-
-async def get_relevant_nodes(
-    nodes: list[EntityNode],
-    driver: AsyncDriver,
-) -> list[EntityNode]:
-    start = time()
-    relevant_nodes: list[EntityNode] = []
-    relevant_node_uuids = set()
-
-    results = await asyncio.gather(
-        *[entity_fulltext_search(node.name, driver) for node in nodes],
-        *[entity_similarity_search(node.name_embedding, driver) for node in nodes],
-    )
-
-    for result in results:
-        for node in result:
-            if node.uuid in relevant_node_uuids:
-                continue
-
-            relevant_node_uuids.add(node.uuid)
-            relevant_nodes.append(node)
-
-    end = time()
-    logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
-
-    return relevant_nodes
-
-
-async def get_relevant_edges(
-    edges: list[EntityEdge],
-    driver: AsyncDriver,
-) -> list[EntityEdge]:
-    start = time()
-    relevant_edges: list[EntityEdge] = []
-    relevant_edge_uuids = set()
-
-    results = await asyncio.gather(
-        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
-        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
-    )
-
-    for result in results:
-        for edge in result:
-            if edge.uuid in relevant_edge_uuids:
-                continue
-
-            relevant_edge_uuids.add(edge.uuid)
-            relevant_edges.append(edge)
-
-    end = time()
-    logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')
-
-    return relevant_edges
@@ -14,8 +14,8 @@ def build_episodic_edges(
     for node in entity_nodes:
         edges.append(
             EpisodicEdge(
-                source_node_uuid=episode,
-                target_node_uuid=node,
+                source_node_uuid=episode.uuid,
+                target_node_uuid=node.uuid,
                 created_at=episode.created_at,
             )
        )
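The episode to episode.uuid fix is the kind of bug a typing pass surfaces: source_node_uuid is a string field, and passing the whole node object only slipped through because the call sites were untyped. With Pydantic models the mismatch fails validation at construction time. An illustrative sketch (models simplified from the diff):

from pydantic import BaseModel, ValidationError

class Node(BaseModel):
    uuid: str

class EpisodicEdge(BaseModel):
    source_node_uuid: str
    target_node_uuid: str

episode, alice = Node(uuid='ep-1'), Node(uuid='node-1')

try:
    EpisodicEdge(source_node_uuid=episode, target_node_uuid=alice)  # type: ignore[arg-type]
except ValidationError:
    pass  # a model instance is not a str

edge = EpisodicEdge(source_node_uuid=episode.uuid, target_node_uuid=alice.uuid)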
@@ -2,7 +2,10 @@
 name = "graphiti"
 version = "0.0.1"
 description = "Graph building library"
-authors = ["Paul Paliychuk <paul@getzep.com>", "Preston Rasmussen <preston@getzep.com>"]
+authors = [
+    "Paul Paliychuk <paul@getzep.com>",
+    "Preston Rasmussen <preston@getzep.com>",
+]
 readme = "README.md"

 [tool.poetry.dependencies]

@@ -56,4 +59,4 @@ ignore = ["E501"]
 [tool.ruff.format]
 quote-style = "single"
 indent-style = "tab"
-docstring-code-format = true
\ No newline at end of file
+docstring-code-format = true
@@ -103,11 +103,11 @@ async def test_graph_integration():
     bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')

     episodic_edge_1 = EpisodicEdge(
-        source_node_uuid=episode, target_node_uuid=alice_node, created_at=now
+        source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
     )

     episodic_edge_2 = EpisodicEdge(
-        source_node_uuid=episode, target_node_uuid=bob_node, created_at=now
+        source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
     )

     entity_edge = EntityEdge(