graphiti/core/graphiti.py

149 lines
4.8 KiB
Python
Raw Normal View History

2024-08-13 14:35:43 -04:00
import asyncio
import logging
from datetime import datetime
from typing import Any, Callable, Tuple
2024-08-13 14:35:43 -04:00
from neo4j import AsyncGraphDatabase
from core.nodes import SemanticNode, EpisodicNode, Node
from core.edges import SemanticEdge, Edge
from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
2024-08-13 14:35:43 -04:00
logger = logging.getLogger(__name__)
class LLMConfig:
    """Settings used to reach the language model.

    Args:
        api_key: Credential for the model provider's API.
        model: Model identifier; ``"gpt-4o"`` unless overridden.
        base_url: Root URL of the provider's API;
            ``"https://api.openai.com"`` unless overridden.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        base_url: str = "https://api.openai.com",
    ):
        # Targets are bound left to right: base_url, then api_key, then model.
        self.base_url, self.api_key, self.model = base_url, api_key, model
2024-08-13 14:35:43 -04:00
class Graphiti:
    """Facade over a Neo4j-backed graph of episodic and semantic nodes.

    Owns an async Neo4j driver and optional LLM settings, and exposes the
    episode-ingestion pipeline (``add_episode``) and ``search``. Several
    methods are still stubs (``...`` / ``pass``) that currently return
    ``None``.
    """

    def __init__(
        self,
        uri: str,
        user: str,
        password: str,
        llm_config: LLMConfig | None = None,
    ):
        """Create the Neo4j driver and store configuration.

        Args:
            uri: Connection URI passed to ``AsyncGraphDatabase.driver``.
            user: Neo4j user name.
            password: Neo4j password.
            llm_config: Language-model settings, or ``None`` (the default)
                when no LLM is configured.
        """
        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
        # Database name; not yet passed to any query in this class.
        self.database = "neo4j"
        self.llm_config = llm_config

    async def close(self):
        """Close the Neo4j driver.

        The async driver's ``close()`` is a coroutine, so it must be awaited;
        calling it without ``await`` only creates a coroutine object and never
        actually closes the driver.
        """
        await self.driver.close()

    async def retrieve_episodes(
        self, last_n: int, sources: str | list[str] | None = "messages"
    ) -> list[EpisodicNode]:
        """Retrieve the last ``last_n`` episodic nodes from the graph.

        Not implemented yet: currently returns ``None``.

        Args:
            last_n: Number of most recent episodes to fetch.
            sources: Presumably a filter on episode source. The default is
                the single string ``"messages"``, so ``str`` is accepted in
                addition to ``list[str]``.
        """
        ...

    # Utility function, to be removed from this class
    async def clear_data(self):
        """Presumably wipes graph data; not implemented yet (no-op)."""
        ...

    async def search(
        self, query: str, config
    ) -> list[Tuple[SemanticNode, list[SemanticEdge]]]:
        """Search the graph for ``query`` (work in progress).

        Combines ``similarity_search`` and ``fulltext_search`` results and
        expands them with ``bfs(..., k=1)``.

        NOTE(review): this sketch is not runnable yet. ``embedder``,
        ``get_episodes``, ``episode_count`` and ``node`` are not defined in
        this file, ``config`` is unused, ``results`` is computed but not
        returned, and the return statement yields a 2-tuple rather than the
        annotated list. TODO: wire these up.
        """
        (vec_nodes, vec_edges) = similarity_search(query, embedder)
        (text_nodes, text_edges) = fulltext_search(query)
        # Build new lists: `list.extend` mutates in place and returns None,
        # which would leave `nodes` / `edges` bound to None.
        nodes = [*vec_nodes, *text_nodes]
        edges = [*vec_edges, *text_edges]
        results = bfs(nodes, edges, k=1)
        episode_ids = ["Mode of episode ids"]
        episodes = get_episodes(episode_ids[:episode_count])
        return [(node, edges)], episodes

    async def get_relevant_schema(
        self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
    ) -> list[Tuple[SemanticNode, list[SemanticEdge]]]:
        """Return the part of the graph relevant to ``episode``.

        Not implemented yet: currently returns ``None``.

        NOTE(review): ``add_episode`` forwards this result to the
        ``extract_*`` methods, which annotate it as ``dict[str, Any]``;
        reconcile the two annotations.
        """
        pass

    # Call the LLM with the specified messages and return the response.
    # Will be used in conjunction with a prompt library.
    async def generate_llm_response(self, messages: list[Any]) -> str:
        """Send ``messages`` to the LLM and return its reply.

        Not implemented yet: currently returns ``None``.
        """
        ...

    # Extract new edges from the episode
    async def extract_new_edges(
        self,
        episode: EpisodicNode,
        new_nodes: list[SemanticNode],
        relevant_schema: dict[str, Any],
        previous_episodes: list[EpisodicNode],
    ) -> list[SemanticEdge]:
        """Extract new edges from ``episode``.

        Not implemented yet: currently returns ``None``.

        Args:
            episode: Episode being processed.
            new_nodes: Nodes already extracted from this episode.
            relevant_schema: Result of ``get_relevant_schema``.
            previous_episodes: Recent episodes, for context.
        """
        ...

    # Extract new nodes from the episode
    async def extract_new_nodes(
        self,
        episode: EpisodicNode,
        relevant_schema: dict[str, Any],
        previous_episodes: list[EpisodicNode],
    ) -> list[SemanticNode]:
        """Extract new nodes from ``episode``.

        Not implemented yet: currently returns ``None``.

        Args:
            episode: Episode being processed.
            relevant_schema: Result of ``get_relevant_schema``.
            previous_episodes: Recent episodes, for context.
        """
        ...

    # Invalidate edges that are no longer valid
    async def invalidate_edges(
        self,
        episode: EpisodicNode,
        new_nodes: list[SemanticNode],
        new_edges: list[SemanticEdge],
        relevant_schema: dict[str, Any],
        previous_episodes: list[EpisodicNode],
    ) -> list[SemanticEdge]:
        """Return edges invalidated by ``episode``.

        Not implemented yet: currently returns ``None``. The return
        annotation reflects that ``add_episode`` extends its edge list with
        this result — TODO confirm once implemented.
        """
        ...

    async def add_episode(
        self,
        name: str,
        episode_body: str,
        source_description: str,
        reference_time: datetime | None = None,
        episode_type: str = "string",
        success_callback: Callable | None = None,
        error_callback: Callable | None = None,
    ):
        """Process an episode and update the graph.

        Steps:
            1. Fetch the last 3 episodes.
            2. Create and save a new ``EpisodicNode``.
            3. Look up the relevant schema.
            4. Extract new nodes and edges.
            5. Add edges from ``build_episodic_edges``.
            6. Add the invalidated edges.
            7. Save all nodes and then all edges concurrently.
            8. Call ``update_summary`` on each ``SemanticNode``.

        Args:
            name, episode_body, source_description, reference_time,
            episode_type: Episode metadata. NOTE(review): none of these is
                used yet, because ``EpisodicNode()`` is built without
                arguments.
            success_callback: Awaited with the episode on success. It runs
                inside the ``try``, so an exception it raises is routed to
                ``error_callback``.
            error_callback: Awaited with ``(episode, exception)`` on
                failure. ``episode`` is ``None`` if the failure happened
                before the episode node was created. When omitted, the
                exception propagates.

        Raises:
            Exception: Whatever the pipeline raised, if no
                ``error_callback`` is given.
        """
        # Bound before `try` so the error path never hits UnboundLocalError
        # when the failure happens before the episode node exists.
        episode = None
        try:
            nodes: list[Node] = []
            edges: list[Edge] = []
            previous_episodes = await self.retrieve_episodes(last_n=3)
            episode = EpisodicNode()
            await episode.save(self.driver)
            # This class defines `get_relevant_schema(episode, previous_episodes)`;
            # no `retrieve_relevant_schema` method exists.
            relevant_schema = await self.get_relevant_schema(
                episode, previous_episodes
            )
            new_nodes = await self.extract_new_nodes(
                episode, relevant_schema, previous_episodes
            )
            nodes.extend(new_nodes)
            new_edges = await self.extract_new_edges(
                episode, new_nodes, relevant_schema, previous_episodes
            )
            edges.extend(new_edges)
            episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
            edges.extend(episodic_edges)
            invalidated_edges = await self.invalidate_edges(
                episode, new_nodes, new_edges, relevant_schema, previous_episodes
            )
            edges.extend(invalidated_edges)
            await asyncio.gather(*[node.save(self.driver) for node in nodes])
            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
            for node in nodes:
                if isinstance(node, SemanticNode):
                    await node.update_summary(self.driver)
            if success_callback:
                await success_callback(episode)
        except Exception as e:
            if error_callback:
                await error_callback(episode, e)
            else:
                # Bare `raise` re-raises the active exception unchanged.
                raise