2024-08-13 14:35:43 -04:00
|
|
|
import asyncio
|
|
|
|
from datetime import datetime
|
|
|
|
import logging
|
2024-08-18 13:22:31 -04:00
|
|
|
from typing import Callable, LiteralString
|
2024-08-13 14:35:43 -04:00
|
|
|
from neo4j import AsyncGraphDatabase
|
2024-08-15 12:03:41 -04:00
|
|
|
from dotenv import load_dotenv
|
|
|
|
import os
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
from core.llm_client.config import EMBEDDING_DIM
|
2024-08-15 11:04:57 -04:00
|
|
|
from core.nodes import EntityNode, EpisodicNode, Node
|
2024-08-20 16:29:19 -04:00
|
|
|
from core.edges import EntityEdge, EpisodicEdge
|
2024-08-15 12:03:41 -04:00
|
|
|
from core.utils import (
|
|
|
|
build_episodic_edges,
|
|
|
|
retrieve_episodes,
|
|
|
|
)
|
|
|
|
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
|
2024-08-20 16:29:19 -04:00
|
|
|
from core.utils.maintenance.edge_operations import (
|
|
|
|
extract_edges,
|
|
|
|
dedupe_extracted_edges,
|
|
|
|
)
|
|
|
|
|
2024-08-18 13:22:31 -04:00
|
|
|
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
2024-08-20 16:29:19 -04:00
|
|
|
from core.utils.maintenance.temporal_operations import (
|
|
|
|
prepare_edges_for_invalidation,
|
|
|
|
invalidate_edges,
|
|
|
|
)
|
2024-08-18 13:22:31 -04:00
|
|
|
from core.utils.search.search_utils import (
|
|
|
|
edge_similarity_search,
|
|
|
|
entity_fulltext_search,
|
|
|
|
bfs,
|
|
|
|
get_relevant_nodes,
|
|
|
|
get_relevant_edges,
|
|
|
|
)
|
2024-08-13 14:35:43 -04:00
|
|
|
|
|
|
|
# Module-level logger, named after this module for hierarchical log config.
logger = logging.getLogger(__name__)

# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file at
# import time, so the default OpenAI client built in Graphiti.__init__ can
# find its credentials without explicit configuration.
load_dotenv()
|
2024-08-14 10:17:12 -04:00
|
|
|
|
|
|
|
|
2024-08-13 14:35:43 -04:00
|
|
|
class Graphiti:
    """Temporal knowledge-graph client backed by Neo4j.

    Ingests "episodes" (units of source content such as chat messages), uses
    an LLM to extract and de-duplicate entity nodes and relationship edges,
    invalidates edges contradicted by newer information, and persists the
    results to the graph. Also provides index creation and hybrid search.
    """

    def __init__(
        self, uri: str, user: str, password: str, llm_client: LLMClient | None = None
    ):
        """Open an async Neo4j driver for *uri* and select an LLM client.

        If *llm_client* is not supplied, an OpenAI client is constructed with
        the API key taken from the OPENAI_API_KEY environment variable
        (populated from .env by the module-level load_dotenv()).
        """
        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
        self.database = "neo4j"
        if llm_client:
            self.llm_client = llm_client
        else:
            self.llm_client = OpenAIClient(
                LLMConfig(
                    api_key=os.getenv("OPENAI_API_KEY"),
                    model="gpt-4o-mini",
                    base_url="https://api.openai.com/v1",
                )
            )

    async def close(self):
        """Close the underlying Neo4j driver, releasing its connection pool.

        Fix: ``AsyncGraphDatabase`` drivers expose an *async* ``close()``.
        The previous synchronous wrapper returned a never-awaited coroutine,
        so the driver was never actually closed. Callers must now ``await``
        this method (the old call was an accidental no-op, so existing
        fire-and-forget callers are not behaviorally worse off).
        """
        await self.driver.close()

    async def retrieve_episodes(
        self, last_n: int, sources: list[str] | str | None = "messages"
    ) -> list[EpisodicNode]:
        """Retrieve the last *last_n* episodic nodes from the graph.

        The default for *sources* is the string ``"messages"`` (kept as-is
        for backward compatibility); the annotation now reflects that a bare
        string is accepted in addition to a list or ``None``.
        """
        return await retrieve_episodes(self.driver, last_n, sources)

    async def add_episode(
        self,
        name: str,
        episode_body: str,
        source_description: str,
        reference_time: datetime | None = None,
        episode_type: str = "string",  # NOTE(review): currently unused; kept for API compat
        success_callback: Callable | None = None,
        error_callback: Callable | None = None,
    ):
        """Process an episode and update the graph.

        Pipeline: extract nodes -> embed + dedupe them -> extract edges ->
        embed + dedupe them -> invalidate contradicted edges -> build
        episodic edges -> persist everything.

        On success, ``success_callback(episode)`` is awaited if provided.
        On failure, ``error_callback(episode, exc)`` is awaited if provided
        (``episode`` may be ``None`` if the failure happened before the
        episode node was constructed); otherwise the exception propagates.
        """
        # Pre-bind so error_callback can always be invoked: previously a
        # failure before EpisodicNode(...) ran caused a NameError here.
        episode: EpisodicNode | None = None
        try:
            nodes: list[EntityNode] = []
            entity_edges: list[EntityEdge] = []
            episodic_edges: list[EpisodicEdge] = []
            embedder = self.llm_client.client.embeddings
            now = datetime.now()

            previous_episodes = await self.retrieve_episodes(last_n=3)
            episode = EpisodicNode(
                name=name,
                labels=[],
                source="messages",
                content=episode_body,
                source_description=source_description,
                created_at=now,
                valid_at=reference_time,
            )

            extracted_nodes = await extract_nodes(
                self.llm_client, episode, previous_episodes
            )

            # Calculate name embeddings for all extracted nodes concurrently.
            await asyncio.gather(
                *[node.generate_name_embedding(embedder) for node in extracted_nodes]
            )
            existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
            logger.info(
                "Extracted nodes: %s", [(n.name, n.uuid) for n in extracted_nodes]
            )
            new_nodes = await dedupe_extracted_nodes(
                self.llm_client, extracted_nodes, existing_nodes
            )
            logger.info(
                "Deduped touched nodes: %s", [(n.name, n.uuid) for n in new_nodes]
            )
            nodes.extend(new_nodes)

            extracted_edges = await extract_edges(
                self.llm_client, episode, new_nodes, previous_episodes
            )

            # Embed edge facts concurrently, mirroring the node embedding step.
            await asyncio.gather(
                *[edge.generate_embedding(embedder) for edge in extracted_edges]
            )

            existing_edges = await get_relevant_edges(extracted_edges, self.driver)
            logger.info(
                "Existing edges: %s", [(e.name, e.uuid) for e in existing_edges]
            )
            logger.info(
                "Extracted edges: %s", [(e.name, e.uuid) for e in extracted_edges]
            )

            deduped_edges = await dedupe_extracted_edges(
                self.llm_client, extracted_edges, existing_edges
            )

            (
                old_edges_with_nodes_pending_invalidation,
                new_edges_with_nodes,
            ) = prepare_edges_for_invalidation(
                existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
            )

            invalidated_edges = await invalidate_edges(
                self.llm_client,
                old_edges_with_nodes_pending_invalidation,
                new_edges_with_nodes,
            )

            entity_edges.extend(invalidated_edges)
            logger.info(
                "Invalidated edges: %s", [(e.name, e.uuid) for e in invalidated_edges]
            )
            logger.info(
                "Deduped edges: %s", [(e.name, e.uuid) for e in deduped_edges]
            )

            entity_edges.extend(deduped_edges)
            episodic_edges.extend(
                build_episodic_edges(
                    # There may be an overlap between new_nodes and affected
                    # nodes, so we're deduplicating them
                    nodes,
                    episode,
                    now,
                )
            )
            # Important to append the episode to the nodes at the end so that
            # self referencing episodic edges are not built
            logger.info("Built episodic edges: %s", episodic_edges)

            # Future optimization would be using batch operations to save
            # nodes and edges instead of one query per object.
            await episode.save(self.driver)
            await asyncio.gather(*[node.save(self.driver) for node in nodes])
            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
            await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])

            if success_callback:
                await success_callback(episode)
        except Exception as e:
            if error_callback:
                # episode may still be None if the failure happened before it
                # was constructed — callbacks must tolerate that.
                await error_callback(episode, e)
            else:
                raise  # bare raise preserves the original traceback

    async def build_indices(self):
        """Create the range, fulltext, and vector indices the graph relies on.

        All statements use ``IF NOT EXISTS`` so this is safe to call on every
        startup.
        """
        index_queries: list[LiteralString] = [
            "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
            "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
            "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
            "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
            "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
            "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
        ]
        # Add the range indices
        for query in index_queries:
            await self.driver.execute_query(query)

        # Add the semantic (fulltext) indices
        await self.driver.execute_query(
            """
            CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
            """
        )

        await self.driver.execute_query(
            """
            CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
            """
        )

        # NOTE(review): vector dimensions are hard-coded to 1024 below while
        # search() truncates query embeddings to EMBEDDING_DIM — confirm these
        # agree. Not changed here because altering dimensions would break
        # existing databases.
        await self.driver.execute_query(
            """
            CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
            FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
            OPTIONS {indexConfig: {
             `vector.dimensions`: 1024,
             `vector.similarity_function`: 'cosine'
            }}
            """
        )

        await self.driver.execute_query(
            """
            CREATE VECTOR INDEX name_embedding IF NOT EXISTS
            FOR (n:Entity) ON (n.name_embedding)
            OPTIONS {indexConfig: {
             `vector.dimensions`: 1024,
             `vector.similarity_function`: 'cosine'
            }}
            """
        )

    async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
        """Hybrid search: embedding similarity over edges plus entity
        fulltext search, expanded via BFS from the matched nodes.

        Returns whatever ``bfs`` produces for the combined node-id frontier.
        """
        # Newlines confuse the embedding input; flatten to a single line.
        text = query.replace("\n", " ")
        search_vector = (
            (
                await self.llm_client.client.embeddings.create(
                    input=[text], model="text-embedding-3-small"
                )
            )
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        edges = await edge_similarity_search(search_vector, self.driver)
        nodes = await entity_fulltext_search(query, self.driver)

        node_ids = [node.uuid for node in nodes]

        for edge in edges:
            node_ids.append(edge.source_node_uuid)
            node_ids.append(edge.target_node_uuid)

        # De-duplicate while preserving first-seen order.
        node_ids = list(dict.fromkeys(node_ids))

        context = await bfs(node_ids, self.driver)

        return context
|