improve deduping issue (#28)

* improve deduping issue

* fix comment

* commit format

* default embeddings

* update
This commit is contained in:
Preston Rasmussen 2024-08-23 12:17:15 -04:00 committed by GitHub
parent 9cc9883e66
commit a1e54881a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 199 additions and 186 deletions

View File

@ -5,66 +5,64 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
edge_list: PromptVersion
v1: PromptVersion
v2: PromptVersion
edge_list: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
edge_list: PromptFunction
v1: PromptFunction
v2: PromptFunction
edge_list: PromptFunction
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates relationship from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates relationship from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
Existing Edges:
Existing Facts:
{json.dumps(context['existing_edges'], indent=2)}
New Edges:
New Facts:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. start with the list of edges from New Edges
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
edge in the list
3. Respond with the resulting list of edges
If any facts in New Facts is a duplicate of a fact in Existing Facts,
do not return it in the list of unique facts.
Guidelines:
1. Use both the name and fact of edges to determine if they are duplicates,
duplicate edges may have different names
1. The facts do not have to be completely identical to be duplicates,
they just need to have similar factual content
Respond with a JSON object in the following format:
{{
"new_edges": [
"unique_facts": [
{{
"fact": "one sentence description of the fact"
"uuid": "unique identifier of the fact"
}}
]
}}
""",
),
]
),
]
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates relationship from edge lists.',
),
Message(
role='user',
content=f"""
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates relationship from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
Existing Edges:
@ -94,44 +92,44 @@ def v2(context: dict[str, Any]) -> list[Message]:
]
}}
""",
),
]
),
]
def edge_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates edges from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, find all of the duplicates in a list of edges:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates edges from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, find all of the duplicates in a list of facts:
Edges:
Facts:
{json.dumps(context['edges'], indent=2)}
Task:
If any edge in Edges is a duplicate of another edge, return the fact of only one of the duplicate edges
If any facts in Facts is a duplicate of another fact, return a new fact with one of their uuid's.
Guidelines:
1. Use both the name and fact of edges to determine if they are duplicates,
edges with the same name may not be duplicates
2. The final list should have only unique facts. If 3 edges are all duplicates of each other, only one of their
1. The facts do not have to be completely identical to be duplicates, they just need to have similar content
2. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of their
facts should be in the response
Respond with a JSON object in the following format:
{{
"unique_edges": [
"unique_facts": [
{{
"fact": "fact of a unique edge",
"uuid": "unique identifier of the fact",
"fact": "fact of a unique edge"
}}
]
}}
""",
),
]
),
]
versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}

View File

@ -3,6 +3,7 @@ import typing
from datetime import datetime
from neo4j import AsyncDriver
from numpy import dot
from pydantic import BaseModel
from core.edges import Edge, EntityEdge, EpisodicEdge
@ -11,186 +12,198 @@ from core.nodes import EntityNode, EpisodicNode
from core.search.search_utils import get_relevant_edges, get_relevant_nodes
from core.utils import retrieve_episodes
from core.utils.maintenance.edge_operations import (
build_episodic_edges,
dedupe_edge_list,
dedupe_extracted_edges,
extract_edges,
build_episodic_edges,
dedupe_edge_list,
dedupe_extracted_edges,
extract_edges,
)
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import (
dedupe_extracted_nodes,
dedupe_node_list,
extract_nodes,
dedupe_extracted_nodes,
dedupe_node_list,
extract_nodes,
)
CHUNK_SIZE = 10
CHUNK_SIZE = 15
class BulkEpisode(BaseModel):
name: str
content: str
source_description: str
episode_type: str
reference_time: datetime
name: str
content: str
source_description: str
episode_type: str
reference_time: datetime
async def retrieve_previous_episodes_bulk(
driver: AsyncDriver, episodes: list[EpisodicNode]
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await asyncio.gather(
*[
retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
for episode in episodes
]
)
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
(episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
]
previous_episodes_list = await asyncio.gather(
*[
retrieve_episodes(driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN)
for episode in episodes
]
)
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
(episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
]
return episode_tuples
return episode_tuples
async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
extracted_nodes_bulk = await asyncio.gather(
*[
extract_nodes(llm_client, episode, previous_episodes)
for episode, previous_episodes in episode_tuples
]
)
extracted_nodes_bulk = await asyncio.gather(
*[
extract_nodes(llm_client, episode, previous_episodes)
for episode, previous_episodes in episode_tuples
]
)
episodes, previous_episodes_list = (
[episode[0] for episode in episode_tuples],
[episode[1] for episode in episode_tuples],
)
episodes, previous_episodes_list = (
[episode[0] for episode in episode_tuples],
[episode[1] for episode in episode_tuples],
)
extracted_edges_bulk = await asyncio.gather(
*[
extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
for i, episode in enumerate(episodes)
]
)
extracted_edges_bulk = await asyncio.gather(
*[
extract_edges(llm_client, episode, extracted_nodes_bulk[i], previous_episodes_list[i])
for i, episode in enumerate(episodes)
]
)
episodic_edges: list[EpisodicEdge] = []
for i, episode in enumerate(episodes):
episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
episodic_edges: list[EpisodicEdge] = []
for i, episode in enumerate(episodes):
episodic_edges += build_episodic_edges(extracted_nodes_bulk[i], episode, episode.created_at)
nodes: list[EntityNode] = []
for extracted_nodes in extracted_nodes_bulk:
nodes += extracted_nodes
nodes: list[EntityNode] = []
for extracted_nodes in extracted_nodes_bulk:
nodes += extracted_nodes
edges: list[EntityEdge] = []
for extracted_edges in extracted_edges_bulk:
edges += extracted_edges
edges: list[EntityEdge] = []
for extracted_edges in extracted_edges_bulk:
edges += extracted_edges
return nodes, edges, episodic_edges
return nodes, edges, episodic_edges
async def dedupe_nodes_bulk(
driver: AsyncDriver,
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
driver: AsyncDriver,
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
) -> tuple[list[EntityNode], dict[str, str]]:
# Compress nodes
nodes, uuid_map = node_name_match(extracted_nodes)
# Compress nodes
nodes, uuid_map = node_name_match(extracted_nodes)
compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
existing_nodes = await get_relevant_nodes(compressed_nodes, driver)
nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
llm_client, compressed_nodes, existing_nodes
)
nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
llm_client, compressed_nodes, existing_nodes
)
compressed_map.update(partial_uuid_map)
compressed_map.update(partial_uuid_map)
return nodes, compressed_map
return nodes, compressed_map
async def dedupe_edges_bulk(
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
) -> list[EntityEdge]:
# Compress edges
compressed_edges = await compress_edges(llm_client, extracted_edges)
# Compress edges
compressed_edges = await compress_edges(llm_client, extracted_edges)
existing_edges = await get_relevant_edges(compressed_edges, driver)
existing_edges = await get_relevant_edges(compressed_edges, driver)
edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
edges = await dedupe_extracted_edges(llm_client, compressed_edges, existing_edges)
return edges
return edges
def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
name_map: dict[str, EntityNode] = {}
for node in nodes:
if node.name in name_map:
uuid_map[node.uuid] = name_map[node.name].uuid
continue
uuid_map: dict[str, str] = {}
name_map: dict[str, EntityNode] = {}
for node in nodes:
if node.name in name_map:
uuid_map[node.uuid] = name_map[node.name].uuid
continue
name_map[node.name] = node
name_map[node.name] = node
return [node for node in name_map.values()], uuid_map
return [node for node in name_map.values()], uuid_map
async def compress_nodes(
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
) -> tuple[list[EntityNode], dict[str, str]]:
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
if len(nodes) == 0:
return nodes, uuid_map
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
anchor = nodes[0]
nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or []))
extended_map = dict(uuid_map)
compressed_nodes: list[EntityNode] = []
for node_chunk, uuid_map_chunk in results:
compressed_nodes += node_chunk
extended_map.update(uuid_map_chunk)
node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
# Check if we have removed all duplicates
if len(compressed_nodes) == len(nodes):
compressed_uuid_map = compress_uuid_map(extended_map)
return compressed_nodes, compressed_uuid_map
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
return await compress_nodes(llm_client, compressed_nodes, extended_map)
extended_map = dict(uuid_map)
compressed_nodes: list[EntityNode] = []
for node_chunk, uuid_map_chunk in results:
compressed_nodes += node_chunk
extended_map.update(uuid_map_chunk)
# Check if we have removed all duplicates
if len(compressed_nodes) == len(nodes):
compressed_uuid_map = compress_uuid_map(extended_map)
return compressed_nodes, compressed_uuid_map
return await compress_nodes(llm_client, compressed_nodes, extended_map)
async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
if len(edges) == 0:
return edges
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
anchor = edges[0]
edges.sort(key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or []))
compressed_edges: list[EntityEdge] = []
for edge_chunk in results:
compressed_edges += edge_chunk
edge_chunks = [edges[i: i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
# Check if we have removed all duplicates
if len(compressed_edges) == len(edges):
return compressed_edges
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
return await compress_edges(llm_client, compressed_edges)
compressed_edges: list[EntityEdge] = []
for edge_chunk in results:
compressed_edges += edge_chunk
# Check if we have removed all duplicates
if len(compressed_edges) == len(edges):
return compressed_edges
return await compress_edges(llm_client, compressed_edges)
def compress_uuid_map(uuid_map: dict[str, str]) -> dict[str, str]:
# make sure all uuid values aren't mapped to other uuids
compressed_map = {}
for key, uuid in uuid_map.items():
curr_value = uuid
while curr_value in uuid_map:
curr_value = uuid_map[curr_value]
# make sure all uuid values aren't mapped to other uuids
compressed_map = {}
for key, uuid in uuid_map.items():
curr_value = uuid
while curr_value in uuid_map:
curr_value = uuid_map[curr_value]
compressed_map[key] = curr_value
return compressed_map
compressed_map[key] = curr_value
return compressed_map
E = typing.TypeVar('E', bound=Edge)
def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
for edge in edges:
source_uuid = edge.source_node_uuid
target_uuid = edge.target_node_uuid
edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
for edge in edges:
source_uuid = edge.source_node_uuid
target_uuid = edge.target_node_uuid
edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
return edges
return edges

View File

@ -94,27 +94,27 @@ async def dedupe_extracted_edges(
) -> list[EntityEdge]:
# Create edge map
edge_map = {}
for edge in existing_edges:
edge_map[edge.fact] = edge
for edge in extracted_edges:
if edge.fact in edge_map:
continue
edge_map[edge.fact] = edge
edge_map[edge.uuid] = edge
# Prepare context for LLM
context = {
'extracted_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
'existing_edges': [{'name': edge.name, 'fact': edge.fact} for edge in extracted_edges],
'extracted_edges': [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges
],
'existing_edges': [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges
],
}
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context))
new_edges_data = llm_response.get('new_edges', [])
logger.info(f'Extracted new edges: {new_edges_data}')
unique_edge_data = llm_response.get('unique_facts', [])
logger.info(f'Extracted unique edges: {unique_edge_data}')
# Get full edge data
edges = []
for edge_data in new_edges_data:
edge = edge_map[edge_data['fact']]
for unique_edge in unique_edge_data:
edge = edge_map[unique_edge['uuid']]
edges.append(edge)
return edges
@ -129,15 +129,15 @@ async def dedupe_edge_list(
# Create edge map
edge_map = {}
for edge in edges:
edge_map[edge.fact] = edge
edge_map[edge.uuid] = edge
# Prepare context for LLM
context = {'edges': [{'name': edge.name, 'fact': edge.fact} for edge in edges]}
context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_edges.edge_list(context)
)
unique_edges_data = llm_response.get('unique_edges', [])
unique_edges_data = llm_response.get('unique_facts', [])
end = time()
logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ')
@ -145,7 +145,9 @@ async def dedupe_edge_list(
# Get full edge data
unique_edges = []
for edge_data in unique_edges_data:
fact = edge_data['fact']
unique_edges.append(edge_map[fact])
uuid = edge_data['uuid']
edge = edge_map[uuid]
edge.fact = edge_data['fact']
unique_edges.append(edge)
return unique_edges

View File

@ -62,10 +62,10 @@ async def main(use_bulk: bool = True):
episode_type='string',
reference_time=message.actual_timestamp,
)
for i, message in enumerate(messages[3:7])
for i, message in enumerate(messages[3:14])
]
await client.add_episode_bulk(episodes)
asyncio.run(main())
asyncio.run(main(True))