"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import typing
from datetime import datetime

import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graph_queries import (
    get_entity_edge_save_bulk_query,
    get_entity_node_save_bulk_query,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import DEFAULT_DATABASE, normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
    EPISODIC_EDGE_SAVE_BULK,
)
from graphiti_core.models.nodes.node_db_queries import (
    EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
from graphiti_core.utils.maintenance.edge_operations import (
    extract_edges,
    resolve_extracted_edge,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
    EPISODE_WINDOW_LEN,
    retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
    extract_nodes,
    resolve_extracted_nodes,
)

logger = logging.getLogger(__name__)

CHUNK_SIZE = 10


class RawEpisode(BaseModel):
    name: str
    uuid: str | None = Field(default=None)
    content: str
    source_description: str
    source: EpisodeType
    reference_time: datetime
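
# Illustrative sketch (not part of the library): constructing a RawEpisode for
# bulk ingestion. The field values and the EpisodeType member used here are
# hypothetical.
#
#   episode = RawEpisode(
#       name='conversation-42',
#       content='Alice met Bob in Paris.',
#       source_description='chat transcript',
#       source=EpisodeType.message,
#       reference_time=datetime.now(timezone.utc),
#   )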


async def retrieve_previous_episodes_bulk(
    driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
    previous_episodes_list = await semaphore_gather(
        *[
            retrieve_episodes(
                driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
            )
            for episode in episodes
        ]
    )
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
        (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
    ]

    return episode_tuples
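
# Note: retrieve_previous_episodes_bulk fans out one retrieve_episodes call per
# episode via semaphore_gather (which bounds concurrency), pairing each episode
# with its previous EPISODE_WINDOW_LEN episodes from the same group.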


async def add_nodes_and_edges_bulk(
    driver: GraphDriver,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
    embedder: EmbedderClient,
):
    session = driver.session(database=DEFAULT_DATABASE)
    try:
        await session.execute_write(
            add_nodes_and_edges_bulk_tx,
            episodic_nodes,
            episodic_edges,
            entity_nodes,
            entity_edges,
            embedder,
            driver=driver,
        )
    finally:
        await session.close()
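
# Illustrative usage sketch (variable names hypothetical): persist a prepared
# batch in a single write transaction.
#
#   await add_nodes_and_edges_bulk(
#       driver, episodic_nodes, episodic_edges, entity_nodes, entity_edges, embedder
#   )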


async def add_nodes_and_edges_bulk_tx(
    tx: GraphDriverSession,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
    embedder: EmbedderClient,
    driver: GraphDriver,
):
    episodes = [dict(episode) for episode in episodic_nodes]
    for episode in episodes:
        episode['source'] = str(episode['source'].value)

    nodes: list[dict[str, Any]] = []
    for node in entity_nodes:
        if node.name_embedding is None:
            await node.generate_name_embedding(embedder)

        entity_data: dict[str, Any] = {
            'uuid': node.uuid,
            'name': node.name,
            'name_embedding': node.name_embedding,
            'group_id': node.group_id,
            'summary': node.summary,
            'created_at': node.created_at,
        }

        entity_data.update(node.attributes or {})
        entity_data['labels'] = list(set(node.labels + ['Entity']))
        nodes.append(entity_data)

    edges: list[dict[str, Any]] = []
    for edge in entity_edges:
        if edge.fact_embedding is None:
            await edge.generate_embedding(embedder)

        edge_data: dict[str, Any] = {
            'uuid': edge.uuid,
            'source_node_uuid': edge.source_node_uuid,
            'target_node_uuid': edge.target_node_uuid,
            'name': edge.name,
            'fact': edge.fact,
            'fact_embedding': edge.fact_embedding,
            'group_id': edge.group_id,
            'episodes': edge.episodes,
            'created_at': edge.created_at,
            'expired_at': edge.expired_at,
            'valid_at': edge.valid_at,
            'invalid_at': edge.invalid_at,
        }

        edge_data.update(edge.attributes or {})
        edges.append(edge_data)

    await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
    entity_node_save_bulk = get_entity_node_save_bulk_query(nodes, driver.provider)
    await tx.run(entity_node_save_bulk, nodes=nodes)
    await tx.run(
        EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
    )
    entity_edge_save_bulk = get_entity_edge_save_bulk_query(driver.provider)
    await tx.run(entity_edge_save_bulk, entity_edges=edges)
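
# Note: add_nodes_and_edges_bulk_tx flattens each node and edge into a plain
# dict, merging any dynamic attributes into the top-level payload, and hands
# lists of these flat dicts to the provider-specific bulk save queries. Missing
# embeddings are generated lazily before the writes.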


async def extract_nodes_and_edges_bulk(
    clients: GraphitiClients,
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    edge_type_map: dict[tuple[str, str], list[str]],
    entity_types: dict[str, BaseModel] | None = None,
    excluded_entity_types: list[str] | None = None,
    edge_types: dict[str, BaseModel] | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
    extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
        *[
            extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
            for episode, previous_episodes in episode_tuples
        ]
    )

    extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
        *[
            extract_edges(
                clients,
                episode,
                extracted_nodes_bulk[i],
                previous_episodes,
                edge_type_map=edge_type_map,
                group_id=episode.group_id,
                edge_types=edge_types,
            )
            for i, (episode, previous_episodes) in enumerate(episode_tuples)
        ]
    )

    return extracted_nodes_bulk, extracted_edges_bulk
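
# Note: extract_nodes_and_edges_bulk awaits node extraction for every episode
# before starting edge extraction, since extract_edges grounds each episode's
# edges in that episode's extracted nodes (extracted_nodes_bulk[i]).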


async def dedupe_nodes_bulk(
    clients: GraphitiClients,
    extracted_nodes: list[list[EntityNode]],
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    entity_types: dict[str, BaseModel] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
    embedder = clients.embedder
    min_score = 0.8

    # generate embeddings
    await semaphore_gather(
        *[create_entity_node_embeddings(embedder, nodes) for nodes in extracted_nodes]
    )

    # Find similar results
    dedupe_tuples: list[tuple[list[EntityNode], list[EntityNode]]] = []
    for i, nodes_i in enumerate(extracted_nodes):
        existing_nodes: list[EntityNode] = []
        for j, nodes_j in enumerate(extracted_nodes):
            if i == j:
                continue
            existing_nodes += nodes_j

        candidates_i: list[EntityNode] = []
        for node in nodes_i:
            for existing_node in existing_nodes:
                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
                # This approach will cast a wider net than BM25, which is ideal for this use case
                node_words = set(node.name.lower().split())
                existing_node_words = set(existing_node.name.lower().split())
                has_overlap = not node_words.isdisjoint(existing_node_words)
                if has_overlap:
                    candidates_i.append(existing_node)
                    continue

                # Check for semantic similarity even if there is no overlap
                similarity = np.dot(
                    normalize_l2(node.name_embedding or []),
                    normalize_l2(existing_node.name_embedding or []),
                )
                if similarity >= min_score:
                    candidates_i.append(existing_node)

        dedupe_tuples.append((nodes_i, candidates_i))

    # Determine Node Resolutions
    bulk_node_resolutions: list[
        tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]
    ] = await semaphore_gather(
        *[
            resolve_extracted_nodes(
                clients,
                dedupe_tuple[0],
                episode_tuples[i][0],
                episode_tuples[i][1],
                entity_types,
                existing_nodes_override=dedupe_tuples[i][1],
            )
            for i, dedupe_tuple in enumerate(dedupe_tuples)
        ]
    )

    # Collect all duplicate pairs sorted by uuid
    duplicate_pairs: list[tuple[EntityNode, EntityNode]] = []
    for _, _, duplicates in bulk_node_resolutions:
        for duplicate in duplicates:
            n, m = duplicate
            if n.uuid < m.uuid:
                duplicate_pairs.append((n, m))
            else:
                duplicate_pairs.append((m, n))

    # Build full deduplication map
    duplicate_map: dict[str, str] = {}
    for value, key in duplicate_pairs:
        if key.uuid in duplicate_map:
            existing_value = duplicate_map[key.uuid]
            duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
        else:
            duplicate_map[key.uuid] = value.uuid

    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
    compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)

    node_uuid_map: dict[str, EntityNode] = {
        node.uuid: node for nodes in extracted_nodes for node in nodes
    }

    nodes_by_episode: dict[str, list[EntityNode]] = {}
    for i, nodes in enumerate(extracted_nodes):
        episode = episode_tuples[i][0]

        nodes_by_episode[episode.uuid] = [
            node_uuid_map[compressed_map.get(node.uuid, node.uuid)] for node in nodes
        ]

    return nodes_by_episode, compressed_map
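
# Note: dedupe_nodes_bulk returns (nodes_by_episode, compressed_map):
# nodes_by_episode maps each episode uuid to its canonical nodes, and
# compressed_map maps every duplicate node uuid to the surviving canonical
# uuid, suitable for passing to resolve_edge_pointers below.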


async def dedupe_edges_bulk(
    clients: GraphitiClients,
    extracted_edges: list[list[EntityEdge]],
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    _entities: list[EntityNode],
    edge_types: dict[str, BaseModel],
    _edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
    embedder = clients.embedder
    min_score = 0.6

    # generate embeddings
    await semaphore_gather(
        *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
    )

    # Find similar results
    dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
    for i, edges_i in enumerate(extracted_edges):
        existing_edges: list[EntityEdge] = []
        for j, edges_j in enumerate(extracted_edges):
            if i == j:
                continue
            existing_edges += edges_j

        for edge in edges_i:
            candidates: list[EntityEdge] = []
            for existing_edge in existing_edges:
                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
                # This approach will cast a wider net than BM25, which is ideal for this use case
                edge_words = set(edge.fact.lower().split())
                existing_edge_words = set(existing_edge.fact.lower().split())
                has_overlap = not edge_words.isdisjoint(existing_edge_words)
                if has_overlap:
                    candidates.append(existing_edge)
                    continue

                # Check for semantic similarity even if there is no overlap
                similarity = np.dot(
                    normalize_l2(edge.fact_embedding or []),
                    normalize_l2(existing_edge.fact_embedding or []),
                )
                if similarity >= min_score:
                    candidates.append(existing_edge)

            dedupe_tuples.append((episode_tuples[i][0], edge, candidates))

    bulk_edge_resolutions: list[
        tuple[EntityEdge, EntityEdge, list[EntityEdge]]
    ] = await semaphore_gather(
        *[
            resolve_extracted_edge(
                clients.llm_client, edge, candidates, candidates, episode, edge_types
            )
            for episode, edge, candidates in dedupe_tuples
        ]
    )

    duplicate_pairs: list[tuple[EntityEdge, EntityEdge]] = []
    for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
        episode, edge, candidates = dedupe_tuples[i]
        for duplicate in duplicates:
            if edge.uuid < duplicate.uuid:
                duplicate_pairs.append((edge, duplicate))
            else:
                duplicate_pairs.append((duplicate, edge))

    # Build full deduplication map
    duplicate_map: dict[str, str] = {}
    for value, key in duplicate_pairs:
        if key.uuid in duplicate_map:
            existing_value = duplicate_map[key.uuid]
            duplicate_map[key.uuid] = value.uuid if value.uuid < existing_value else existing_value
        else:
            duplicate_map[key.uuid] = value.uuid

    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
    compressed_map: dict[str, str] = compress_uuid_map(duplicate_map)

    edge_uuid_map: dict[str, EntityEdge] = {
        edge.uuid: edge for edges in extracted_edges for edge in edges
    }

    edges_by_episode: dict[str, list[EntityEdge]] = {}
    for i, edges in enumerate(extracted_edges):
        episode = episode_tuples[i][0]

        edges_by_episode[episode.uuid] = [
            edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
        ]

    return edges_by_episode
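
# Note: dedupe_edges_bulk mirrors the node deduplication above, comparing edge
# facts instead of node names and using a looser similarity threshold (0.6 vs
# 0.8 for nodes).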


def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
    compressed_map = {}

    def find_min_uuid(start: str) -> str:
        path = []
        visited = set()
        curr = start

        while curr in uuid_map and curr not in visited:
            visited.add(curr)
            path.append(curr)
            curr = uuid_map[curr]

        # Also include the last resolved value (could be outside the map)
        path.append(curr)

        # Resolve to lex smallest UUID in the path
        min_uuid = min(path)

        # Assign all UUIDs in the path to the min_uuid
        for node in path:
            compressed_map[node] = min_uuid

        return min_uuid

    for key in uuid_map:
        if key not in compressed_map:
            find_min_uuid(key)

    return compressed_map
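
# Worked example of compress_uuid_map: chains of duplicate uuids collapse to
# the lexicographically smallest member of each chain, e.g.
#
#   compress_uuid_map({'c': 'b', 'b': 'a'}) == {'c': 'a', 'b': 'a', 'a': 'a'}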


E = typing.TypeVar('E', bound=Edge)


def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
    for edge in edges:
        source_uuid = edge.source_node_uuid
        target_uuid = edge.target_node_uuid
        edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
        edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

    return edges
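
# Illustrative usage sketch (variable names hypothetical): after deduplication,
# repoint edges at the canonical node uuids in place.
#
#   entity_edges = resolve_edge_pointers(entity_edges, compressed_map)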