"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import logging
import os
from datetime import datetime
from time import time
from typing import Callable

from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase

from core.edges import EntityEdge, EpisodicEdge
from core.llm_client import LLMClient, LLMConfig, OpenAIClient
from core.nodes import EntityNode, EpisodicNode
from core.search.search import SearchConfig, hybrid_search
from core.search.search_utils import (
    get_relevant_edges,
    get_relevant_nodes,
)
from core.utils import (
    build_episodic_edges,
    retrieve_episodes,
)
from core.utils.bulk_utils import (
    BulkEpisode,
    dedupe_edges_bulk,
    dedupe_nodes_bulk,
    extract_nodes_and_edges_bulk,
    resolve_edge_pointers,
    retrieve_previous_episodes_bulk,
)
from core.utils.maintenance.edge_operations import (
    dedupe_extracted_edges,
    extract_edges,
)
from core.utils.maintenance.graph_data_operations import (
    EPISODE_WINDOW_LEN,
    build_indices_and_constraints,
)
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
    extract_edge_dates,
    extract_node_edge_node_triplet,
    invalidate_edges,
    prepare_edges_for_invalidation,
)
logger = logging.getLogger(__name__)

load_dotenv()


class Graphiti:
def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
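        """
        Initialize a Graphiti instance with a Neo4j connection and an LLM client.

        If no llm_client is provided, an OpenAIClient configured from the
        OPENAI_API_KEY environment variable is used.
        """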
self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
self.database = 'neo4j'
if llm_client:
self.llm_client = llm_client
else:
self.llm_client = OpenAIClient(
LLMConfig(
api_key=os.getenv('OPENAI_API_KEY', default=''),
model='gpt-4o-mini',
base_url='https://api.openai.com/v1',
)
)

    async def close(self):
        """Close the underlying Neo4j driver connection."""
        await self.driver.close()

async def build_indices_and_constraints(self):
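        """Create the Neo4j indices and constraints that Graphiti relies on."""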
await build_indices_and_constraints(self.driver)

    async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
return await retrieve_episodes(self.driver, reference_time, last_n)

    async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""Process an episode and update the graph"""
try:
start = time()
nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.get_embedder()
now = datetime.now()
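            # Pull the most recent episodes so extraction has conversational context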
previous_episodes = await self.retrieve_episodes(reference_time)
episode = EpisodicNode(
name=name,
labels=[],
source='messages',
content=episode_body,
source_description=source_description,
created_at=now,
valid_at=reference_time,
)
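            # Extract candidate entity nodes from the episode via the LLM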
extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
# Calculate Embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
            logger.info(f'Existing nodes: {[(n.name, n.uuid) for n in existing_nodes]}')
touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
nodes.extend(touched_nodes)
extracted_edges = await extract_edges(
self.llm_client, episode, touched_nodes, previous_episodes
)
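            # Calculate embeddings for the extracted edges so similar existing edges can be found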
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
existing_edges = await get_relevant_edges(extracted_edges, self.driver)
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
# deduped_edges = await dedupe_extracted_edges_v2(
# self.llm_client,
# extract_node_and_edge_triplets(extracted_edges, nodes),
# extract_node_and_edge_triplets(existing_edges, nodes),
# )
deduped_edges = await dedupe_extracted_edges(
self.llm_client,
extracted_edges,
existing_edges,
)
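            # Collect the nodes involved in new or invalidated edges (plus brand-new nodes);
            # only these receive episodic edges below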
edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
for edge in deduped_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
) = prepare_edges_for_invalidation(
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
)
invalidated_edges = await invalidate_edges(
self.llm_client,
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
episode,
previous_episodes,
)
for edge in invalidated_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
            # There may be overlap between deduped and invalidated edges, so make sure the
            # invalidated version is the one that gets saved. Copy the list so that extending
            # edges_to_save does not also mutate invalidated_edges.
            edges_to_save = list(invalidated_edges)
            for deduped_edge in deduped_edges:
                if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
                    edges_to_save.append(deduped_edge)
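            # Resolve when each deduped edge became valid and, where applicable, invalid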
for deduped_edge in deduped_edges:
triplet = extract_node_edge_node_triplet(deduped_edge, nodes)
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client, triplet, episode.valid_at, episode, previous_episodes
)
deduped_edge.valid_at = valid_at
deduped_edge.invalid_at = invalid_at
entity_edges.extend(edges_to_save)
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]
logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
episodic_edges.extend(
build_episodic_edges(
                    # involved_nodes was deduplicated by uuid above, so each node gets exactly one episodic edge
involved_nodes,
episode,
now,
)
)
            # Important: append the episode to the node list only at the end, so that self-referencing episodic edges are not built
logger.info(f'Built episodic edges: {episodic_edges}')
# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
end = time()
            logger.info(f'Completed add_episode in {(end - start) * 1000:.0f} ms')
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
if success_callback:
await success_callback(episode)
        except Exception as e:
            if error_callback:
                # episode may be None if the failure happened before it was created
                await error_callback(episode, e)
            else:
                raise

    async def add_episode_bulk(
self,
bulk_episodes: list[BulkEpisode],
):
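        """Process a batch of episodes in bulk: save the episodes, extract and dedupe
        nodes and edges across the whole batch, then persist them.

        Note: unlike add_episode, this path does not yet invalidate edges.
        """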
try:
start = time()
embedder = self.llm_client.get_embedder()
now = datetime.now()
episodes = [
EpisodicNode(
name=episode.name,
labels=[],
source='messages',
content=episode.content,
source_description=episode.source_description,
created_at=now,
valid_at=episode.reference_time,
)
for episode in bulk_episodes
]
# Save all the episodes
await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
# Get previous episode context for each episode
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
# Extract all nodes and edges
(
extracted_nodes,
extracted_edges,
episodic_edges,
) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
# Generate embeddings
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes],
*[edge.generate_embedding(embedder) for edge in extracted_edges],
)
# Dedupe extracted nodes
nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes)
# save nodes to KG
await asyncio.gather(*[node.save(self.driver) for node in nodes])
            # Re-map edge pointers so that they don't point to discarded duplicate nodes
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
extracted_edges, uuid_map
)
episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
episodic_edges, uuid_map
)
# save episodic edges to KG
await asyncio.gather(
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
)
# Dedupe extracted edges
edges = await dedupe_edges_bulk(
self.driver, self.llm_client, extracted_edges_with_resolved_pointers
)
            logger.info(f'Deduped edges: {len(edges)}')
            # Edge invalidation is not yet applied in the bulk path
            # Save edges to KG
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
end = time()
            logger.info(f'Completed add_episode_bulk in {(end - start) * 1000:.0f} ms')
        except Exception:
            logger.exception('Error in add_episode_bulk')
            raise

    async def search(self, query: str, num_results: int = 10) -> list[str]:
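        """Run a hybrid search for `query` and return the facts attached to the matching edges."""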
search_config = SearchConfig(num_episodes=0, num_results=num_results)
edges = (
await hybrid_search(
self.driver,
self.llm_client.get_embedder(),
query,
datetime.now(),
search_config,
)
).edges
facts = [edge.fact for edge in edges]
return facts

    async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
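        """Run a hybrid search with an explicit timestamp and a full SearchConfig, returning raw results."""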
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config
)
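

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the library API. It assumes a
# Neo4j instance at bolt://localhost:7687 with the credentials below and an
# OPENAI_API_KEY set in the environment; the URI, credentials, and episode
# contents are illustrative placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    async def _example() -> None:
        graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
        try:
            # Create the indices and constraints the search paths rely on
            await graphiti.build_indices_and_constraints()
            # Ingest one episode; extraction, deduping, and saving happen inside add_episode
            await graphiti.add_episode(
                name='example-episode',
                episode_body='Alice met Bob at the conference on Tuesday.',
                source_description='example chat log',
                reference_time=datetime.now(),
            )
            # Hybrid search returns the facts attached to matching edges
            print(await graphiti.search('Who did Alice meet?'))
        finally:
            await graphiti.close()

    asyncio.run(_example())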