graphiti/core/graphiti.py

import asyncio
import logging
import os
from datetime import datetime
from typing import Any, Callable, LiteralString

from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase

from core.edges import Edge, EntityEdge
from core.llm_client import LLMClient, LLMConfig, OpenAIClient
from core.nodes import EntityNode, EpisodicNode, Node
from core.utils import (
    build_episodic_edges,
    clear_data,
    extract_new_edges,
    extract_new_nodes,
    retrieve_episodes,
    retrieve_relevant_schema,
)
logger = logging.getLogger(__name__)

# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()


class Graphiti:
def __init__(
self, uri: str, user: str, password: str, llm_client: LLMClient | None = None
):
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = "neo4j"
if llm_client:
self.llm_client = llm_client
else:
self.llm_client = OpenAIClient(
LLMConfig(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o",
base_url="https://api.openai.com/v1",
)
)

    async def close(self):
        # AsyncGraphDatabase drivers expose close() as a coroutine; it must
        # be awaited or the connection is never actually closed.
        await self.driver.close()

    async def retrieve_episodes(
        self, last_n: int, sources: str | list[str] | None = "messages"
    ) -> list[EpisodicNode]:
        """Retrieve the last n episodic nodes from the graph"""
        return await retrieve_episodes(self.driver, last_n, sources)

    async def retrieve_relevant_schema(self, query: str | None = None) -> dict[str, Any]:
        """Retrieve the nodes and edges relevant to a specific query"""
        return await retrieve_relevant_schema(self.driver, query)
...

    # Invalidate edges that are no longer valid
    async def invalidate_edges(
        self,
        episode: EpisodicNode,
        new_nodes: list[EntityNode],
        new_edges: list[EntityEdge],
        relevant_schema: dict[str, Any],
        previous_episodes: list[EpisodicNode],
    ): ...

    async def add_episode(
        self,
        name: str,
        episode_body: str,
        source_description: str,
        reference_time: datetime | None = None,
        episode_type: str = "string",
        success_callback: Callable | None = None,
        error_callback: Callable | None = None,
    ):
"""Process an episode and update the graph"""
try:
nodes: list[Node] = []
edges: list[Edge] = []
previous_episodes = await self.retrieve_episodes(last_n=3)
episode = EpisodicNode(
name=name,
labels=[],
source="messages",
content=episode_body,
source_description=source_description,
created_at=datetime.now(),
valid_at=reference_time,
)
# await episode.save(self.driver)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
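            # Extract new entity nodes from the episode, using the relevant
            # schema and the recent episodes as context for the LLM.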
new_nodes = await extract_new_nodes(
self.llm_client, episode, relevant_schema, previous_episodes
)
nodes.extend(new_nodes)
new_edges, affected_nodes = await extract_new_edges(
self.llm_client, episode, new_nodes, relevant_schema, previous_episodes
)
edges.extend(new_edges)
episodic_edges = build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
list(set(nodes + affected_nodes)),
episode,
datetime.now(),
)
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
nodes.append(episode)
logger.info(f"Built episodic edges: {episodic_edges}")
edges.extend(episodic_edges)
# invalidated_edges = await self.invalidate_edges(
# episode, new_nodes, new_edges, relevant_schema, previous_episodes
# )
# edges.extend(invalidated_edges)
# Future optimization would be using batch operations to save nodes and edges
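            # Nodes are saved before edges so that both endpoints of every
            # edge already exist in the graph when the edge is persisted.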
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
if success_callback:
await success_callback(episode)
        except Exception as e:
            if error_callback:
                await error_callback(episode, e)
            else:
                # Re-raise as a bare `raise` to preserve the original traceback
                raise

    async def build_indices(self):
index_queries: list[LiteralString] = [
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
"CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.name)",
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.created_at)",
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.expired_at)",
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.valid_at)",
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.invalid_at)",
]
# Add the range indices
for query in index_queries:
await self.driver.execute_query(query)
        # Add the full-text index over entity names and summaries
await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
"""
)
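        # Add the vector index; `vector.dimensions` must match the
        # dimensionality of the embeddings stored on r.fact_embedding,
        # assumed here to be 1024.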
await self.driver.execute_query(
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
"""
)

    async def search(
        self, query: str, config
    ) -> list[tuple[EntityNode, list[EntityEdge]]]:
        # NOTE: sketch of the planned hybrid search flow. similarity_search,
        # fulltext_search, bfs, get_episodes, embedder, and episode_count are
        # assumed to be provided elsewhere and are not implemented here yet.
        vec_nodes, vec_edges = similarity_search(query, embedder)
        text_nodes, text_edges = fulltext_search(query)
        # list.extend mutates in place and returns None, so concatenate instead
        nodes = vec_nodes + text_nodes
        edges = vec_edges + text_edges
        results = bfs(nodes, edges, k=1)
        # TODO: take the mode of the episode ids referenced by the results
        episode_ids = []
        episodes = get_episodes(episode_ids[:episode_count])
        return results, episodes
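

# A minimal usage sketch (illustrative only; assumes NEO4J_URI, NEO4J_USER,
# and NEO4J_PASSWORD are set in the environment and that a Neo4j instance is
# reachable at that URI).
if __name__ == "__main__":

    async def _example() -> None:
        client = Graphiti(
            os.getenv("NEO4J_URI", "bolt://localhost:7687"),
            os.getenv("NEO4J_USER", "neo4j"),
            os.getenv("NEO4J_PASSWORD", ""),
        )
        await client.build_indices()
        await client.add_episode(
            name="example-episode",
            episode_body="Alice met Bob at the conference.",
            source_description="example message",
            reference_time=datetime.now(),
        )
        await client.close()

    asyncio.run(_example())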