2024-08-15 12:03:41 -04:00
|
|
|
import json
|
2024-08-22 12:26:13 -07:00
|
|
|
import logging
|
2024-08-15 12:03:41 -04:00
|
|
|
from datetime import datetime
|
2024-08-21 12:03:32 -04:00
|
|
|
from time import time
|
2024-08-22 12:26:13 -07:00
|
|
|
from typing import List
|
2024-08-15 12:03:41 -04:00
|
|
|
|
2024-08-22 12:26:13 -07:00
|
|
|
from core.edges import EntityEdge, EpisodicEdge
|
|
|
|
from core.llm_client import LLMClient
|
2024-08-15 12:03:41 -04:00
|
|
|
from core.nodes import EntityNode, EpisodicNode
|
|
|
|
from core.prompts import prompt_library
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
def build_episodic_edges(
    entity_nodes: List[EntityNode],
    episode: EpisodicNode,
    created_at: datetime,
) -> List[EpisodicEdge]:
    """Connect an episode to each of the given entity nodes.

    Args:
        entity_nodes: Entity nodes that should be linked to ``episode``.
        episode: Episodic node used as the source of every created edge.
        created_at: Timestamp assigned to every created edge.

    Returns:
        One ``EpisodicEdge`` per element of ``entity_nodes``, in the same
        order, each going from ``episode.uuid`` to that entity's ``uuid``.
    """
    return [
        EpisodicEdge(
            source_node_uuid=episode.uuid,
            target_node_uuid=entity.uuid,
            created_at=created_at,
        )
        for entity in entity_nodes
    ]
|
2024-08-15 12:03:41 -04:00
|
|
|
|
|
|
|
|
|
|
|
def _resolve_edge_endpoint(node_name, nodes_by_name, schema_nodes):
    """Return the EntityNode an LLM-named edge endpoint refers to, or None.

    A node extracted from the current episode (``nodes_by_name``) wins. If
    there is none, and ``node_name`` is a key of ``schema_nodes``, a
    lightweight EntityNode is built from that schema entry's ``uuid`` and
    ``label`` with an empty summary.
    """
    node = nodes_by_name.get(node_name)
    if node is not None:
        return node
    if node_name not in schema_nodes:
        return None
    existing_node_data = schema_nodes[node_name]
    return EntityNode(
        uuid=existing_node_data['uuid'],
        name=node_name,
        labels=[existing_node_data['label']],
        summary='',
        created_at=datetime.now(),
    )


async def extract_new_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
    new_nodes: list[EntityNode],
    relevant_schema: dict[str, any],
    previous_episodes: list[EpisodicNode],
) -> tuple[list[EntityEdge], list[EntityNode]]:
    """Ask the LLM (prompt ``extract_edges.v1``) for new edges in ``episode``.

    Edge endpoints returned by the LLM are names. Each name is resolved
    first against ``new_nodes``; if a name matches several nodes, the first
    one wins. It is then resolved against ``relevant_schema['nodes']``.

    Args:
        llm_client: Client used to run the prompt.
        episode: Episode whose content is analysed. Its uuid is recorded on
            every created edge.
        new_nodes: Entity nodes extracted from this episode.
        relevant_schema: Serialized into the prompt as JSON. If it has a
            ``'nodes'`` entry, that entry is used as a mapping from node name
            to a dict with ``'uuid'`` and ``'label'`` keys.
            NOTE(review): annotated with the builtin ``any``; ``typing.Any``
            was presumably intended.
        previous_episodes: Earlier episodes passed to the LLM as context.

    Returns:
        A tuple ``(new_edges, affected_nodes)``. ``affected_nodes`` holds the
        distinct source/target nodes of ``new_edges``, in no particular order.

    Raises:
        KeyError: if an LLM edge lacks ``source_node``, ``target_node``,
            ``relation_type`` or ``fact``.
        ValueError: if an LLM-supplied ``valid_at``/``invalid_at`` is not an
            ISO-format timestamp.
    """
    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
        'relevant_schema': json.dumps(relevant_schema, indent=2),
        'new_nodes': [{'name': node.name, 'summary': node.summary} for node in new_nodes],
        'previous_episodes': [
            {
                'content': ep.content,
                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(prompt_library.extract_edges.v1(context))
    new_edges_data = llm_response.get('new_edges', [])
    logger.info(f'Extracted new edges: {new_edges_data}')

    # Build the name lookup once; setdefault keeps the first node for a name,
    # matching a linear "first match" search.
    nodes_by_name: dict[str, EntityNode] = {}
    for node in new_nodes:
        nodes_by_name.setdefault(node.name, node)
    schema_nodes = relevant_schema.get('nodes', {})

    # Convert the extracted data into EntityEdge objects
    new_edges = []
    for edge_data in new_edges_data:
        source_node = _resolve_edge_endpoint(edge_data['source_node'], nodes_by_name, schema_nodes)
        target_node = _resolve_edge_endpoint(edge_data['target_node'], nodes_by_name, schema_nodes)

        # Drop edges with an unresolvable endpoint.
        if not (source_node and target_node):
            continue
        # Edges to or from nodes whose name starts with 'Message' are skipped.
        if source_node.name.startswith('Message') or target_node.name.startswith('Message'):
            continue

        # The LLM may omit the timestamp keys entirely; treat that like an empty value.
        raw_valid_at = edge_data.get('valid_at')
        valid_at = (
            datetime.fromisoformat(raw_valid_at)
            if raw_valid_at
            else episode.valid_at or datetime.now()
        )
        raw_invalid_at = edge_data.get('invalid_at')
        invalid_at = datetime.fromisoformat(raw_invalid_at) if raw_invalid_at else None

        # NOTE(review): extract_edges in this module builds EntityEdge with
        # source_node_uuid/target_node_uuid; confirm EntityEdge still accepts
        # the source_node/target_node fields used here.
        new_edge = EntityEdge(
            source_node=source_node,
            target_node=target_node,
            name=edge_data['relation_type'],
            fact=edge_data['fact'],
            episodes=[episode.uuid],
            created_at=datetime.now(),
            valid_at=valid_at,
            invalid_at=invalid_at,
        )
        new_edges.append(new_edge)
        logger.info(
            f'Created new edge: {new_edge.name} from {source_node.name} (UUID: {source_node.uuid}) to {target_node.name} (UUID: {target_node.uuid})'
        )

    affected_nodes = {
        node for edge in new_edges for node in (edge.source_node, edge.target_node)
    }
    return new_edges, list(affected_nodes)
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
|
|
|
async def extract_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
    """Ask the LLM (prompt ``extract_edges.v2``) for edges between ``nodes``.

    The LLM refers to endpoints by uuid. Edges whose ``source_node_uuid`` or
    ``target_node_uuid`` is missing or empty are dropped.

    Args:
        llm_client: Client used to run the prompt.
        episode: Episode whose content is analysed. Its uuid is recorded on
            every created edge.
        nodes: Candidate entity nodes (uuid, name and summary are sent to the LLM).
        previous_episodes: Earlier episodes passed to the LLM as context.

    Returns:
        The created ``EntityEdge`` objects. ``valid_at``/``invalid_at`` are
        left as ``None``.

    Raises:
        KeyError: if an LLM edge with both endpoints lacks ``relation_type``
            or ``fact``.
    """
    start = time()

    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
        'nodes': [
            {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes
        ],
        'previous_episodes': [
            {
                'content': ep.content,
                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
            }
            for ep in previous_episodes
        ],
    }

    llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context))
    edges_data = llm_response.get('edges', [])

    end = time()
    logger.info(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')

    # Convert the extracted data into EntityEdge objects
    edges = []
    for edge_data in edges_data:
        # LLM output is not guaranteed to contain both keys; a missing key is
        # treated like an empty uuid instead of aborting the whole extraction.
        source_uuid = edge_data.get('source_node_uuid')
        target_uuid = edge_data.get('target_node_uuid')
        if not (target_uuid and source_uuid):
            continue

        edge = EntityEdge(
            source_node_uuid=source_uuid,
            target_node_uuid=target_uuid,
            name=edge_data['relation_type'],
            fact=edge_data['fact'],
            episodes=[episode.uuid],
            created_at=datetime.now(),
            valid_at=None,
            invalid_at=None,
        )
        edges.append(edge)
        logger.info(
            f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
        )

    return edges
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
|
|
|
async def dedupe_extracted_edges(
    llm_client: LLMClient,
    extracted_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
) -> list[EntityEdge]:
    """Ask the LLM (prompt ``dedupe_edges.v1``) which extracted edges are new.

    Both edge lists are sent to the LLM as ``{'name', 'fact'}`` pairs. The
    facts it returns under ``'new_edges'`` are mapped back to edge objects.

    Args:
        llm_client: Client used to run the prompt.
        extracted_edges: Edges just extracted.
        existing_edges: Edges to compare against.

    Returns:
        The edges whose facts the LLM listed, in the LLM's order and without
        repeats. When a fact belongs to both lists, the object from
        ``existing_edges`` is returned. Facts that match no known edge are
        logged and skipped.
    """
    # Map fact text -> edge. Existing edges take precedence over extracted
    # edges that share the same fact text.
    edge_map = {edge.fact: edge for edge in existing_edges}
    for edge in extracted_edges:
        edge_map.setdefault(edge.fact, edge)

    # Prepare context for LLM
    context = {
        'extracted_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
        # Previously built from extracted_edges, so the LLM never saw the existing edges.
        'existing_edges': [{'name': edge.name, 'fact': edge.fact} for edge in existing_edges],
    }

    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
    new_edges_data = llm_response.get('new_edges', [])
    logger.info(f'Extracted new edges: {new_edges_data}')

    # Get full edge data
    edges = []
    seen_facts = set()
    for edge_data in new_edges_data:
        fact = edge_data.get('fact')
        if fact in seen_facts:
            continue
        edge = edge_map.get(fact)
        if edge is None:
            # The LLM may paraphrase or invent a fact; skip rather than crash.
            logger.warning(f'Deduplication returned an unknown fact, skipping: {fact!r}')
            continue
        seen_facts.add(fact)
        edges.append(edge)

    return edges
|
2024-08-21 12:03:32 -04:00
|
|
|
|
|
|
|
|
|
|
|
async def dedupe_edge_list(
    llm_client: LLMClient,
    edges: list[EntityEdge],
) -> list[EntityEdge]:
    """Ask the LLM (prompt ``dedupe_edges.edge_list``) to de-duplicate ``edges``.

    Edges are sent to the LLM as ``{'name', 'fact'}`` pairs. The facts it
    returns under ``'unique_edges'`` are mapped back to edge objects.

    Args:
        llm_client: Client used to run the prompt.
        edges: Edges to de-duplicate.

    Returns:
        The edges whose facts the LLM listed, in the LLM's order and without
        repeats. If several input edges share a fact, the last one is used.
        Facts that match no input edge are logged and skipped.
    """
    start = time()

    # Map fact text -> edge (last edge wins for a repeated fact).
    edge_map = {edge.fact: edge for edge in edges}

    # Prepare context for LLM
    context = {'edges': [{'name': edge.name, 'fact': edge.fact} for edge in edges]}

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.edge_list(context)
    )
    unique_edges_data = llm_response.get('unique_edges', [])

    end = time()
    logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')

    # Get full edge data
    unique_edges = []
    seen_facts = set()
    for edge_data in unique_edges_data:
        fact = edge_data.get('fact')
        if fact in seen_facts:
            # A fact listed twice must not yield the same edge twice.
            continue
        edge = edge_map.get(fact)
        if edge is None:
            # The LLM may paraphrase or invent a fact; skip rather than crash.
            logger.warning(f'Edge deduplication returned an unknown fact, skipping: {fact!r}')
            continue
        seen_facts.add(fact)
        unique_edges.append(edge)

    return unique_edges
|