2024-08-23 13:01:33 -07:00
|
|
|
"""
|
|
|
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
|
2024-09-03 13:25:52 -04:00
|
|
|
import asyncio
|
2024-08-22 12:26:13 -07:00
|
|
|
import logging
|
2024-08-15 12:03:41 -04:00
|
|
|
from datetime import datetime
|
2024-08-21 12:03:32 -04:00
|
|
|
from time import time
|
2024-08-22 12:26:13 -07:00
|
|
|
from typing import List
|
2024-08-15 12:03:41 -04:00
|
|
|
|
2024-08-25 10:07:50 -07:00
|
|
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
|
|
|
from graphiti_core.llm_client import LLMClient
|
|
|
|
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
|
|
|
|
from graphiti_core.prompts import prompt_library
|
2024-09-05 12:05:44 -04:00
|
|
|
from graphiti_core.utils.maintenance.temporal_operations import (
|
|
|
|
|
extract_edge_dates,
|
|
|
|
|
get_edge_contradictions,
|
|
|
|
|
)
|
2024-08-15 12:03:41 -04:00
|
|
|
|
|
|
|
|
# Module-level logger; the edge operations below report LLM results and timings through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_episodic_edges(
    entity_nodes: List[EntityNode],
    episode: EpisodicNode,
    created_at: datetime,
) -> List[EpisodicEdge]:
    """Link an episode to each of the given entity nodes.

    Args:
        entity_nodes: Entity nodes the episode should point at.
        episode: The episodic node used as the source of every edge.
        created_at: Timestamp stamped onto every created edge.

    Returns:
        One ``EpisodicEdge`` per entity node, in input order, each directed from
        ``episode`` to the corresponding node.
    """
    return [
        EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=entity_node.uuid,
            created_at=created_at,
        )
        for entity_node in entity_nodes
    ]
|
2024-08-15 12:03:41 -04:00
|
|
|
|
|
|
|
|
|
2024-08-18 13:22:31 -04:00
|
|
|
async def extract_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    """Ask the LLM for relationships between ``nodes`` found in ``episode``.

    The ``extract_edges.v2`` prompt receives the episode content and timestamp, the
    candidate entity nodes (uuid, name, summary) and the previous episodes as
    context. Every item returned under ``'edges'`` that names both a source and a
    target node uuid becomes a new ``EntityEdge`` attached to ``episode``.
    ``valid_at``/``invalid_at`` are left as ``None`` here.

    Items whose source or target uuid is missing or empty are skipped with a
    warning, so that one malformed item does not abort the whole extraction.

    Args:
        llm_client: Client used to run the extraction prompt.
        episode: Episode whose content is analysed.
        nodes: Entity nodes the LLM may connect.
        previous_episodes: Earlier episodes supplied as context.

    Returns:
        The newly created edges, in the order the LLM returned them.

    Raises:
        KeyError: if an item that has both endpoints lacks ``relation_type`` or ``fact``.
    """
    start = time()

    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
        'nodes': [
            {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes
        ],
        'previous_episodes': [
            {
                'content': ep.content,
                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context))
    edges_data = llm_response.get('edges', [])

    end = time()
    logger.info(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')

    # Convert the extracted data into EntityEdge objects
    edges: list[EntityEdge] = []
    for edge_data in edges_data:
        source_node_uuid = edge_data.get('source_node_uuid')
        target_node_uuid = edge_data.get('target_node_uuid')
        # LLM output is untrusted: an edge without both endpoints cannot be attached
        # to the graph, so skip it instead of raising KeyError for the whole batch.
        if not (source_node_uuid and target_node_uuid):
            logger.warning(f'Skipping extracted edge without both endpoints: {edge_data}')
            continue

        edge = EntityEdge(
            source_node_uuid=source_node_uuid,
            target_node_uuid=target_node_uuid,
            name=edge_data['relation_type'],
            fact=edge_data['fact'],
            episodes=[episode.uuid],
            created_at=datetime.now(),
            valid_at=None,
            invalid_at=None,
        )
        edges.append(edge)
        logger.info(
            f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
        )

    return edges
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
|
2024-08-22 18:09:44 -04:00
|
|
|
def create_edge_identifier(
    source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> str:
    """Return a readable key of the form ``"<source name>-<edge name>-<target name>"``."""
    pattern = '{}-{}-{}'
    return pattern.format(source_node.name, edge.name, target_node.name)
|
2024-08-22 18:09:44 -04:00
|
|
|
|
|
|
|
|
|
2024-08-18 13:22:31 -04:00
|
|
|
async def dedupe_extracted_edges(
    llm_client: LLMClient,
    extracted_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
    """Replace extracted edges that the LLM reports as duplicates of existing edges.

    Both edge lists are sent to the ``dedupe_edges.v1`` prompt as (uuid, name, fact)
    records. Each entry of the response's ``'duplicates'`` list maps an extracted
    edge ``uuid`` to the ``duplicate_of`` uuid of an existing edge.

    Args:
        llm_client: Client used to run the dedupe prompt.
        extracted_edges: Newly extracted edges to check.
        existing_edges: Edges already known that may be duplicated.

    Returns:
        One edge per entry of ``extracted_edges``, in the same order. This is the
        matching existing edge when a duplicate was reported, otherwise the
        extracted edge itself. A ``duplicate_of`` uuid that is not among
        ``existing_edges``, or a duplicate entry missing either uuid, is ignored.
    """
    # Create edge map
    edge_map: dict[str, EntityEdge] = {edge.uuid: edge for edge in existing_edges}

    # Prepare context for LLM
    context = {
        'extracted_edges': [
            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
        ],
        'existing_edges': [
            {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
        ],
    }

    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
    duplicate_data = llm_response.get('duplicates', [])
    logger.info(f'Extracted duplicate edges: {duplicate_data}')

    duplicate_uuid_map: dict[str, str] = {}
    for duplicate in duplicate_data:
        extracted_uuid = duplicate.get('uuid')
        duplicate_of = duplicate.get('duplicate_of')
        if extracted_uuid and duplicate_of:
            duplicate_uuid_map[extracted_uuid] = duplicate_of

    # Get full edge data
    edges: list[EntityEdge] = []
    for edge in extracted_edges:
        existing_uuid = duplicate_uuid_map.get(edge.uuid)
        if existing_uuid is None:
            edges.append(edge)
            continue

        existing_edge = edge_map.get(existing_uuid)
        if existing_edge is None:
            # The LLM referenced an edge it was never given; keep the extracted edge
            # rather than failing the whole batch with a KeyError.
            logger.warning(
                f'Ignoring unknown duplicate_of uuid {existing_uuid} for edge {edge.uuid}'
            )
            edges.append(edge)
        else:
            edges.append(existing_edge)

    return edges
|
2024-08-21 12:03:32 -04:00
|
|
|
|
|
|
|
|
|
2024-09-03 13:25:52 -04:00
|
|
|
async def resolve_extracted_edges(
    llm_client: LLMClient,
    extracted_edges: list[EntityEdge],
    related_edges_lists: list[list[EntityEdge]],
    existing_edges_lists: list[list[EntityEdge]],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
    """Resolve every extracted edge concurrently via ``resolve_extracted_edge``.

    ``extracted_edges``, ``related_edges_lists`` and ``existing_edges_lists`` are
    paired positionally with ``zip``. Trailing entries of a longer list are
    therefore ignored.

    Returns:
        ``(resolved_edges, invalidated_edges)``. ``resolved_edges`` holds one edge
        per processed extracted edge, in order. ``invalidated_edges`` concatenates
        the invalidated edges reported for each of them, in the same order.
    """
    # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
    pending = [
        resolve_extracted_edge(
            llm_client,
            edge,
            related,
            existing,
            current_episode,
            previous_episodes,
        )
        for edge, related, existing in zip(
            extracted_edges, related_edges_lists, existing_edges_lists
        )
    ]
    results: list[tuple[EntityEdge, list[EntityEdge]]] = list(await asyncio.gather(*pending))

    resolved_edges: list[EntityEdge] = [resolved for resolved, _ in results]
    invalidated_edges: list[EntityEdge] = [
        invalidated for _, chunk in results for invalidated in chunk
    ]

    return resolved_edges, invalidated_edges
|
2024-09-03 13:25:52 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]:
    """Resolve one extracted edge and apply temporal invalidation rules.

    Three LLM-backed steps run concurrently:

    * deduplicate ``extracted_edge`` against ``related_edges``;
    * extract ``(valid_at, invalid_at)`` dates for ``extracted_edge`` from the
      current and previous episodes;
    * find edges among ``existing_edges`` that contradict ``extracted_edge``.

    The results are then combined:

    1. Extracted dates that are not ``None`` are written onto the resolved edge. If
       an ``invalid_at`` was extracted, the edge is expired (unless it already is).
    2. If the resolved edge is still unexpired, it is expired by the earliest
       contradicting edge whose ``valid_at`` is later than its own.
    3. Contradicting edges that do not overlap in time with the resolved edge are
       left alone. Those that became valid before the resolved edge get their
       ``invalid_at`` set to the resolved edge's ``valid_at`` and are expired.

    Returns:
        ``(resolved_edge, invalidated_edges)``. The resolved edge and the
        invalidated candidate edges are mutated in place.
    """
    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
    )

    # Single timestamp used for every expiration made by this call.
    # NOTE(review): naive local time; confirm it matches how valid_at/invalid_at are stored,
    # since naive and aware datetimes cannot be compared.
    now = datetime.now()

    # Only overwrite dates the extraction actually produced; keep existing values otherwise.
    resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
    resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
    if invalid_at is not None and resolved_edge.expired_at is None:
        resolved_edge.expired_at = now

    # Determine if the new_edge needs to be expired
    if resolved_edge.expired_at is None:
        # Sort in place by valid_at ascending, with undated candidates last. The loop
        # below then picks the earliest later-valid candidate. The sort order is also
        # the order of invalidated_edges further down.
        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
        for candidate in invalidation_candidates:
            if (
                candidate.valid_at is not None and resolved_edge.valid_at is not None
            ) and candidate.valid_at > resolved_edge.valid_at:
                # Expire new edge since we have information about more recent events
                resolved_edge.invalid_at = candidate.valid_at
                resolved_edge.expired_at = now
                break

    # Determine which contradictory edges need to be expired
    invalidated_edges: list[EntityEdge] = []
    for edge in invalidation_candidates:
        # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
        if (
            edge.invalid_at is not None
            and resolved_edge.valid_at is not None
            and edge.invalid_at < resolved_edge.valid_at
        ) or (
            edge.valid_at is not None
            and resolved_edge.invalid_at is not None
            and resolved_edge.invalid_at < edge.valid_at
        ):
            continue
        # New edge invalidates edge
        elif (
            edge.valid_at is not None
            and resolved_edge.valid_at is not None
            and edge.valid_at < resolved_edge.valid_at
        ):
            edge.invalid_at = resolved_edge.valid_at
            # Keep an earlier expiration timestamp if the edge was already expired.
            edge.expired_at = edge.expired_at if edge.expired_at is not None else now
            invalidated_edges.append(edge)

    return resolved_edge, invalidated_edges
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def dedupe_extracted_edge(
    llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
) -> EntityEdge:
    """Return the related edge the LLM identifies as a duplicate, else ``extracted_edge``.

    The ``dedupe_edges.v3`` prompt receives the related edges and the extracted edge
    as (uuid, name, fact) records. If the response has a truthy ``is_duplicate``
    and its ``uuid`` matches a related edge, that related edge is returned. When
    several related edges share the uuid, the last one wins. In every other case
    the extracted edge is returned unchanged.
    """
    started = time()

    # Prepare context for LLM
    llm_input = {
        'related_edges': [
            {'uuid': candidate.uuid, 'name': candidate.name, 'fact': candidate.fact}
            for candidate in related_edges
        ],
        'extracted_edges': {
            'uuid': extracted_edge.uuid,
            'name': extracted_edge.name,
            'fact': extracted_edge.fact,
        },
    }

    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v3(llm_input))

    is_duplicate: bool = llm_response.get('is_duplicate', False)
    duplicate_uuid: str | None = llm_response.get('uuid', None)

    resolved = extracted_edge
    if is_duplicate:
        # Scan from the end so that the last matching edge is chosen.
        resolved = next(
            (
                candidate
                for candidate in reversed(related_edges)
                if candidate.uuid == duplicate_uuid
            ),
            extracted_edge,
        )

    elapsed_ms = (time() - started) * 1000
    logger.info(f'Resolved Edge: {extracted_edge.name} is {resolved.name}, in {elapsed_ms} ms')

    return resolved
|
|
|
|
|
|
|
|
|
|
|
2024-08-21 12:03:32 -04:00
|
|
|
async def dedupe_edge_list(
    llm_client: LLMClient,
    edges: list[EntityEdge],
) -> list[EntityEdge]:
    """Ask the LLM to collapse ``edges`` into a list of unique facts.

    The ``dedupe_edges.edge_list`` prompt receives (uuid, fact) pairs. Each entry of
    the response's ``'unique_facts'`` list names one of the input edges by uuid and
    may supply a rewritten fact.

    Args:
        llm_client: Client used to run the dedupe prompt.
        edges: Edges to deduplicate.

    Returns:
        The input edges named in ``'unique_facts'``, in response order. Their
        ``fact`` is replaced in place by the returned fact, and the original fact
        is kept when the LLM omits one. Entries naming an unknown uuid are skipped
        with a warning rather than raising ``KeyError``.
    """
    start = time()

    # Create edge map
    edge_map: dict[str, EntityEdge] = {edge.uuid: edge for edge in edges}

    # Prepare context for LLM
    context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.edge_list(context)
    )
    unique_edges_data = llm_response.get('unique_facts', [])

    end = time()
    logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')

    # Get full edge data
    unique_edges: list[EntityEdge] = []
    for edge_data in unique_edges_data:
        uuid = edge_data.get('uuid')
        edge = edge_map.get(uuid) if uuid is not None else None
        if edge is None:
            # The LLM referenced an edge it was never given; drop the entry
            # instead of failing the whole batch.
            logger.warning(f'Ignoring unique fact with unknown edge uuid: {edge_data}')
            continue
        edge.fact = edge_data.get('fact', edge.fact)
        unique_edges.append(edge)

    return unique_edges
|