import asyncio
from datetime import datetime
import logging
from typing import Any, Callable, Tuple

from neo4j import AsyncGraphDatabase

from core.nodes import SemanticNode, EpisodicNode, Node
from core.edges import SemanticEdge, Edge
from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges

logger = logging.getLogger(__name__)


class LLMConfig:
    """Configuration for the language model"""

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        base_url: str = "https://api.openai.com",
    ):
        self.base_url = base_url
        self.api_key = api_key
        self.model = model


class Graphiti:
    def __init__(
        self, uri: str, user: str, password: str, llm_config: LLMConfig | None
    ):
        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
        self.database = "neo4j"
        if llm_config:
            self.llm_config = llm_config
        else:
            self.llm_config = None

    async def close(self):
        # The async neo4j driver's close() is a coroutine and must be awaited.
        await self.driver.close()

    async def retrieve_episodes(
        self, last_n: int, sources: list[str] | None = None
    ) -> list[EpisodicNode]:
        """Retrieve the last n episodic nodes from the graph"""
        ...

    # Utility function, to be removed from this class
    async def clear_data(self): ...

    async def search(
        self, query: str, config
    ) -> Tuple[list[Tuple[SemanticNode, list[SemanticEdge]]], list[EpisodicNode]]:
        # NOTE: `embedder`, `get_episodes`, and `episode_count` are not defined
        # yet in this draft; they stand in for the embedding client, an episode
        # lookup helper, and a result limit.
        vec_nodes, vec_edges = similarity_search(query, embedder)
        text_nodes, text_edges = fulltext_search(query)

        nodes = vec_nodes + text_nodes
        edges = vec_edges + text_edges

        # Expand the combined results one hop out into the graph.
        results = bfs(nodes, edges, k=1)

        episode_ids = ["Mode of episode ids"]

        episodes = get_episodes(episode_ids[:episode_count])

        # Provisional: pair each node with the full edge list.
        return [(node, edges) for node in nodes], episodes

    async def get_relevant_schema(
        self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
    ) -> list[Tuple[SemanticNode, list[SemanticEdge]]]:
        pass

    # Call llm with the specified messages, and return the response
    # Will be used in conjunction with a prompt library
    async def generate_llm_response(self, messages: list[Any]) -> str: ...
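
    # Hypothetical sketch, not part of the draft API above: one way
    # generate_llm_response could be implemented with the OpenAI Python client,
    # driven by the LLMConfig fields. The method name, the chat-completions call,
    # and the assumption that llm_config is set are all illustrative.
    async def _generate_llm_response_sketch(self, messages: list[Any]) -> str:
        from openai import AsyncOpenAI  # assumed dependency

        # base_url is omitted here; the OpenAI client expects it to include the
        # API version suffix (e.g. ".../v1"), unlike the LLMConfig default.
        client = AsyncOpenAI(api_key=self.llm_config.api_key)
        response = await client.chat.completions.create(
            model=self.llm_config.model,
            messages=messages,
        )
        return response.choices[0].message.content or ""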

    # Extract new edges from the episode
    async def extract_new_edges(
        self,
        episode: EpisodicNode,
        new_nodes: list[SemanticNode],
        relevant_schema: dict[str, Any],
        previous_episodes: list[EpisodicNode],
    ) -> list[SemanticEdge]: ...

    # Extract new nodes from the episode
    async def extract_new_nodes(
        self,
        episode: EpisodicNode,
        relevant_schema: dict[str, Any],
        previous_episodes: list[EpisodicNode],
    ) -> list[SemanticNode]: ...

    # Invalidate edges that are no longer valid
    async def invalidate_edges(
        self,
        episode: EpisodicNode,
        new_nodes: list[SemanticNode],
        new_edges: list[SemanticEdge],
        relevant_schema: dict[str, Any],
        previous_episodes: list[EpisodicNode],
    ): ...

    async def add_episode(
        self,
        name: str,
        episode_body: str,
        source_description: str,
        reference_time: datetime | None = None,
        episode_type: str = "string",
        success_callback: Callable | None = None,
        error_callback: Callable | None = None,
    ):
        """Process an episode and update the graph"""
        try:
            nodes: list[Node] = []
            edges: list[Edge] = []
            previous_episodes = await self.retrieve_episodes(last_n=3)
            # NOTE: the episode's fields (name, content, source description,
            # timestamps) still need to be populated from the arguments above.
            episode = EpisodicNode()
            await episode.save(self.driver)
            relevant_schema = await self.get_relevant_schema(episode, previous_episodes)
            new_nodes = await self.extract_new_nodes(
                episode, relevant_schema, previous_episodes
            )
            nodes.extend(new_nodes)
            new_edges = await self.extract_new_edges(
                episode, new_nodes, relevant_schema, previous_episodes
            )
            edges.extend(new_edges)
            episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
            edges.extend(episodic_edges)

            invalidated_edges = await self.invalidate_edges(
                episode, new_nodes, new_edges, relevant_schema, previous_episodes
            )

            edges.extend(invalidated_edges)

            await asyncio.gather(*[node.save(self.driver) for node in nodes])
            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
            for node in nodes:
                if isinstance(node, SemanticNode):
                    await node.update_summary(self.driver)
            if success_callback:
                await success_callback(episode)
        except Exception as e:
            if error_callback:
                await error_callback(episode, e)
            else:
                raise
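

# A minimal usage sketch, not part of the library code above. It assumes a local
# Neo4j instance reachable at bolt://localhost:7687 with neo4j/password
# credentials and an OPENAI_API_KEY environment variable; the episode contents
# are illustrative placeholders.
async def _example() -> None:
    import os

    llm_config = LLMConfig(api_key=os.environ["OPENAI_API_KEY"])
    client = Graphiti("bolt://localhost:7687", "neo4j", "password", llm_config)
    try:
        # Ingest a single conversational episode into the graph.
        await client.add_episode(
            name="conversation-1",
            episode_body="Alice told Bob she moved to Berlin.",
            source_description="chat transcript",
            reference_time=datetime.now(),
        )
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(_example())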