graphiti/core/graphiti.py
Pavlo Paliychuk 83c7640d9c
chore: Initial draft of stubs (#2)
* chore: Initial draft of stubs

* updates

* chore: Add comments and mock implementation of the add_episode method

* chore: Add success and error callbacks

* stub updates

---------

Co-authored-by: prestonrasmussen <prasmuss15@gmail.com>
2024-08-14 10:17:12 -04:00

149 lines
4.8 KiB
Python

import asyncio
from datetime import datetime
import logging
from typing import Callable, Tuple
from neo4j import AsyncGraphDatabase
from core.nodes import SemanticNode, EpisodicNode, Node
from core.edges import SemanticEdge, Edge
from core.utils import bfs, similarity_search, fulltext_search, build_episodic_edges
logger = logging.getLogger(__name__)
class LLMConfig:
"""Configuration for the language model"""
def __init__(
self,
api_key: str,
model: str = "gpt-4o",
base_url: str = "https://api.openai.com",
):
self.base_url = base_url
self.api_key = api_key
self.model = model
class Graphiti:
def __init__(
self, uri: str, user: str, password: str, llm_config: LLMConfig | None
):
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = "neo4j"
if llm_config:
self.llm_config = llm_config
else:
self.llm_config = None
def close(self):
self.driver.close()
async def retrieve_episodes(
self, last_n: int, sources: list[str] | None = "messages"
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
...
# Utility function, to be removed from this class
async def clear_data(self): ...
async def search(self, query: str, config) -> (
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
(vec_nodes, vec_edges) = similarity_search(query, embedder)
(text_nodes, text_edges) = fulltext_search(query)
nodes = vec_nodes.extend(text_nodes)
edges = vec_edges.extend(text_edges)
results = bfs(nodes, edges, k=1)
episode_ids = ["Mode of episode ids"]
episodes = get_episodes(episode_ids[:episode_count])
return [(node, edges)], episodes
async def get_relevant_schema(self, episode: EpisodicNode, previous_episodes: list[EpisodicNode]) -> (
list)[Tuple[SemanticNode, list[SemanticEdge]]]:
pass
# Call llm with the specified messages, and return the response
# Will be used in the conjunction with a prompt library
async def generate_llm_response(self, messages: list[any]) -> str: ...
# Extract new edges from the episode
async def extract_new_edges(
self,
episode: EpisodicNode,
new_nodes: list[SemanticNode],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[SemanticEdge]: ...
# Extract new nodes from the episode
async def extract_new_nodes(
self,
episode: EpisodicNode,
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
) -> list[SemanticNode]: ...
# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[SemanticNode],
new_edges: list[SemanticEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime = None,
episode_type="string",
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""Process an episode and update the graph"""
try:
nodes: list[Node] = []
edges: list[Edge] = []
previous_episodes = await self.retrieve_episodes(last_n=3)
episode = EpisodicNode()
await episode.save(self.driver)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
new_nodes = await self.extract_new_nodes(
episode, relevant_schema, previous_episodes
)
nodes.extend(new_nodes)
new_edges = await self.extract_new_edges(
episode, new_nodes, relevant_schema, previous_episodes
)
edges.extend(new_edges)
episodic_edges = build_episodic_edges(nodes, episode, datetime.now())
edges.extend(episodic_edges)
invalidated_edges = await self.invalidate_edges(
episode, new_nodes, new_edges, relevant_schema, previous_episodes
)
edges.extend(invalidated_edges)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
for node in nodes:
if isinstance(node, SemanticNode):
await node.update_summary(self.driver)
if success_callback:
await success_callback(episode)
except Exception as e:
if error_callback:
await error_callback(episode, e)
else:
raise e