2024-09-16 14:03:05 -04:00
|
|
|
"""
|
|
|
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
|
2024-08-18 13:22:31 -04:00
|
|
|
import asyncio
|
|
|
|
|
import logging
|
2024-08-26 10:30:22 -04:00
|
|
|
import re
|
2024-08-22 14:26:26 -04:00
|
|
|
from collections import defaultdict
|
2024-08-21 12:03:32 -04:00
|
|
|
from time import time
|
2024-08-18 13:22:31 -04:00
|
|
|
|
2024-09-04 10:05:45 -04:00
|
|
|
from neo4j import AsyncDriver, Query
|
2024-08-18 13:22:31 -04:00
|
|
|
|
2024-09-06 12:33:42 -04:00
|
|
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
2024-09-16 14:03:05 -04:00
|
|
|
from graphiti_core.nodes import (
|
|
|
|
|
CommunityNode,
|
|
|
|
|
EntityNode,
|
|
|
|
|
EpisodicNode,
|
|
|
|
|
get_community_node_from_record,
|
|
|
|
|
get_entity_node_from_record,
|
|
|
|
|
)
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
# Module-level logger named after this module; used below to report search timings.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
2024-08-21 12:03:32 -04:00
|
|
|
# Default value of the `limit` parameter for the search helpers in this module.
RELEVANT_SCHEMA_LIMIT = 3
|
|
|
|
|
|
2024-08-18 13:22:31 -04:00
|
|
|
|
2024-09-18 15:44:28 -04:00
|
|
|
async def get_mentioned_nodes(
    driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
    """
    Fetch the entity nodes that the given episodes mention.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance used to run the query.
    episodes : list[EpisodicNode]
        Episodes whose uuids are matched against `Episodic` nodes.

    Returns
    -------
    list[EntityNode]
        One EntityNode per distinct row returned by the query, built with
        get_entity_node_from_record.
    """
    episode_uuids = [episode.uuid for episode in episodes]
    # Every RETURN item must be comma-separated; a missing comma after
    # name_embedding makes the whole statement a Cypher syntax error.
    records, _, _ = await driver.execute_query(
        """
        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
        RETURN DISTINCT
            n.uuid As uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
            n.summary AS summary
        """,
        uuids=episode_uuids,
    )

    nodes = [get_entity_node_from_record(record) for record in records]

    return nodes
|
2024-08-22 14:26:26 -04:00
|
|
|
|
|
|
|
|
|
2024-09-18 15:44:28 -04:00
|
|
|
async def get_communities_by_nodes(
    driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    """
    Fetch the communities that have any of the given entity nodes as a member.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance used to run the query.
    nodes : list[EntityNode]
        Entity nodes whose uuids are matched on the member side of `HAS_MEMBER`.

    Returns
    -------
    list[CommunityNode]
        One CommunityNode per distinct row returned by the query, built with
        get_community_node_from_record.
    """
    node_uuids = [node.uuid for node in nodes]
    # Every RETURN item must be comma-separated; a missing comma after
    # name_embedding makes the whole statement a Cypher syntax error.
    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
        RETURN DISTINCT
            c.uuid As uuid,
            c.group_id AS group_id,
            c.name AS name,
            c.name_embedding AS name_embedding,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        uuids=node_uuids,
    )

    communities = [get_community_node_from_record(record) for record in records]

    return communities
|
|
|
|
|
|
|
|
|
|
|
2024-09-16 14:03:05 -04:00
|
|
|
async def edge_fulltext_search(
    driver: AsyncDriver,
    query: str,
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """
    Full-text search over edge names and facts.

    Queries the "name_and_fact" relationship full-text index and keeps only
    relationships between `Entity` nodes, optionally pinned to specific endpoints
    and restricted to a set of group ids.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance used to run the query.
    query : str
        Free text to search for. Every character that is not a word character or
        whitespace is removed and a trailing '~' is appended before it is sent to
        the index.
    source_node_uuid : str | None
        If given, the `n` endpoint must have this uuid; if None it is unconstrained.
    target_node_uuid : str | None
        If given, the `m` endpoint must have this uuid; if None it is unconstrained.
    group_ids : list[str] | None, optional
        If given, only edges whose group_id is in this list are returned.
    limit : int, optional
        Maximum number of rows returned (applied with `LIMIT` after ordering).

    Returns
    -------
    list[EntityEdge]
        Matching edges, highest index score first.
    """
    # Constrain an endpoint only when the caller supplied its uuid. The pieces are
    # fixed literals, so the query text never contains caller data; all values are
    # still passed as query parameters.
    source_pattern = (
        '(n:Entity)' if source_node_uuid is None else '(n:Entity {uuid: $source_uuid})'
    )
    target_pattern = (
        '(m:Entity)' if target_node_uuid is None else '(m:Entity {uuid: $target_uuid})'
    )

    # NOTE(review): the relationship pattern is undirected, so with an unconstrained
    # endpoint a relationship can match in both orientations - confirm callers
    # de-duplicate by uuid where that matters.
    # The group filter is applied to the edge (r.group_id) in every variant.
    cypher_query = Query(
        """
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH """
        + source_pattern
        + '-[r {uuid: rel.uuid}]-'
        + target_pattern
        + """
        WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC LIMIT $limit
        """
    )

    # Drop punctuation (which could be read as query operators) and append '~'.
    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'

    records, _, _ = await driver.execute_query(
        cypher_query,
        query=fuzzy_query,
        source_uuid=source_node_uuid,
        target_uuid=target_node_uuid,
        group_ids=group_ids,
        limit=limit,
    )

    edges = [get_entity_edge_from_record(record) for record in records]

    return edges
|
|
|
|
|
|
|
|
|
|
|
2024-08-18 13:22:31 -04:00
|
|
|
async def edge_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """
    Vector similarity search over embedded edge facts.

    Queries the "fact_embedding" relationship vector index and keeps only
    relationships between `Entity` nodes, optionally pinned to specific endpoints
    and restricted to a set of group ids.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance used to run the query.
    search_vector : list[float]
        The embedding to compare against stored fact embeddings.
    source_node_uuid : str | None
        If given, the `n` endpoint must have this uuid; if None it is unconstrained.
    target_node_uuid : str | None
        If given, the `m` endpoint must have this uuid; if None it is unconstrained.
    group_ids : list[str] | None, optional
        If given, only edges whose group_id is in this list are returned.
    limit : int, optional
        Number of candidates requested from the vector index. The endpoint and
        group filters run after the index lookup, so fewer rows may come back.

    Returns
    -------
    list[EntityEdge]
        Matching edges, highest similarity score first.
    """
    # Constrain an endpoint only when the caller supplied its uuid. The pieces are
    # fixed literals, so the query text never contains caller data; all values are
    # still passed as query parameters.
    source_pattern = (
        '(n:Entity)' if source_node_uuid is None else '(n:Entity {uuid: $source_uuid})'
    )
    target_pattern = (
        '(m:Entity)' if target_node_uuid is None else '(m:Entity {uuid: $target_uuid})'
    )

    query = Query(
        """
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH """
        + source_pattern
        + '-[r {uuid: rel.uuid}]-'
        + target_pattern
        + """
        WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            r.created_at AS created_at,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        ORDER BY score DESC
        """
    )

    records, _, _ = await driver.execute_query(
        query,
        search_vector=search_vector,
        source_uuid=source_node_uuid,
        target_uuid=target_node_uuid,
        group_ids=group_ids,
        limit=limit,
    )

    edges = [get_entity_edge_from_record(record) for record in records]

    return edges
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
|
2024-09-16 14:03:05 -04:00
|
|
|
async def node_fulltext_search(
    driver: AsyncDriver,
    query: str,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Full-text search over the "name_and_summary" node index.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance used to run the query.
    query : str
        Free text to search for. Characters that are neither word characters nor
        whitespace are stripped and a trailing '~' is appended before querying.
    group_ids : list[str] | None, optional
        If given, only nodes whose group_id is in this list are returned.
    limit : int, optional
        Maximum number of rows returned.

    Returns
    -------
    list[EntityNode]
        Matching nodes, highest index score first.
    """
    # Sanitise the text first, then mark it with '~' for the index query.
    stripped_text = re.sub(r'[^\w\s]', '', query)
    index_query = stripped_text + '~'

    rows, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query)
        YIELD node AS n, score
        WHERE $group_ids IS NULL OR n.group_id IN $group_ids
        RETURN
            n.uuid AS uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        LIMIT $limit
        """,
        query=index_query,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_node_from_record(row) for row in rows]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def node_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Vector similarity search over the "name_embedding" node index.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance used to run the query.
    search_vector : list[float]
        The embedding to compare against stored name embeddings.
    group_ids : list[str] | None, optional
        If given, only nodes whose group_id is in this list are returned.
    limit : int, optional
        Number of candidates requested from the vector index. The label and group
        filters run after the index lookup, so fewer rows may come back.

    Returns
    -------
    list[EntityNode]
        Matching `Entity` nodes, highest similarity score first.
    """
    rows, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
        YIELD node AS n, score
        MATCH (n:Entity)
        WHERE $group_ids IS NULL OR n.group_id IN $group_ids
        RETURN
            n.uuid As uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.created_at AS created_at,
            n.summary AS summary
        ORDER BY score DESC
        """,
        search_vector=search_vector,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_entity_node_from_record(row) for row in rows]
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
|
2024-09-16 14:03:05 -04:00
|
|
|
async def community_fulltext_search(
    driver: AsyncDriver,
    query: str,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """
    Full-text search over the "community_name" node index.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance used to run the query.
    query : str
        Free text to search for. Characters that are neither word characters nor
        whitespace are stripped and a trailing '~' is appended before querying.
    group_ids : list[str] | None, optional
        If given, only communities whose group_id is in this list are returned.
    limit : int, optional
        Maximum number of rows returned.

    Returns
    -------
    list[CommunityNode]
        Matching `Community` nodes, highest index score first.
    """
    # Sanitise the text first, then mark it with '~' for the index query.
    stripped_text = re.sub(r'[^\w\s]', '', query)
    index_query = stripped_text + '~'

    rows, _, _ = await driver.execute_query(
        """
        CALL db.index.fulltext.queryNodes("community_name", $query)
        YIELD node AS comm, score
        MATCH (comm:Community)
        WHERE $group_ids IS NULL OR comm.group_id in $group_ids
        RETURN
            comm.uuid AS uuid,
            comm.group_id AS group_id,
            comm.name AS name,
            comm.name_embedding AS name_embedding,
            comm.created_at AS created_at,
            comm.summary AS summary
        ORDER BY score DESC
        LIMIT $limit
        """,
        query=index_query,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_community_node_from_record(row) for row in rows]
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
|
2024-09-16 14:03:05 -04:00
|
|
|
async def community_similarity_search(
    driver: AsyncDriver,
    search_vector: list[float],
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """
    Vector similarity search over the "community_name_embedding" node index.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance used to run the query.
    search_vector : list[float]
        The embedding to compare against stored community name embeddings.
    group_ids : list[str] | None, optional
        If given, only communities whose group_id is in this list are returned.
    limit : int, optional
        Number of candidates requested from the vector index. The label and group
        filters run after the index lookup, so fewer rows may come back.

    Returns
    -------
    list[CommunityNode]
        Matching `Community` nodes, highest similarity score first.
    """
    rows, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
        YIELD node AS comm, score
        MATCH (comm:Community)
        WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
        RETURN
            comm.uuid As uuid,
            comm.group_id AS group_id,
            comm.name AS name,
            comm.name_embedding AS name_embedding,
            comm.created_at AS created_at,
            comm.summary AS summary
        ORDER BY score DESC
        """,
        search_vector=search_vector,
        group_ids=group_ids,
        limit=limit,
    )

    return [get_community_node_from_record(row) for row in rows]
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
|
2024-08-26 20:00:28 -07:00
|
|
|
async def hybrid_node_search(
    queries: list[str],
    embeddings: list[list[float]],
    driver: AsyncDriver,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Search for nodes with both text queries and embeddings, then fuse the rankings.

    One fulltext search is run per query and one vector similarity search per
    embedding, all concurrently. The per-search rankings are merged with the
    rrf reranker.

    Parameters
    ----------
    queries : list[str]
        Text queries; each one is passed to node_fulltext_search.
    embeddings : list[list[float]]
        Embedding vectors; each one is passed to node_similarity_search. If empty,
        only fulltext searches are performed.
    driver : AsyncDriver
        The Neo4j driver instance for database operations.
    group_ids : list[str] | None, optional
        The list of group ids to retrieve nodes from.
    limit : int, optional
        Each individual search is run with a limit of `2 * limit`. The fused list
        itself is not truncated.

    Returns
    -------
    list[EntityNode]
        Unique EntityNode objects (one per uuid) in the order produced by rrf.

    Notes
    -----
    When the same uuid is returned by several searches, the node object from the
    last search that returned it is the one kept. The uuids found and the elapsed
    time in milliseconds are logged at INFO level.
    """
    start = time()

    # Fulltext searches first, then similarity searches; gather preserves this order.
    fulltext_searches = [
        node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries
    ]
    similarity_searches = [
        node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings
    ]
    results: list[list[EntityNode]] = list(
        await asyncio.gather(*fulltext_searches, *similarity_searches)
    )

    # Walk the results once, collecting both the uuid -> node lookup and the
    # per-search uuid rankings that rrf consumes.
    node_uuid_map: dict[str, EntityNode] = {}
    result_uuids: list[list[str]] = []
    for result in results:
        ranking: list[str] = []
        for node in result:
            node_uuid_map[node.uuid] = node
            ranking.append(node.uuid)
        result_uuids.append(ranking)

    ranked_uuids = rrf(result_uuids)

    relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]

    end = time()
    logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
    return relevant_nodes
|
2024-08-18 13:22:31 -04:00
|
|
|
|
2024-08-26 20:00:28 -07:00
|
|
|
|
|
|
|
|
async def get_relevant_nodes(
    nodes: list[EntityNode],
    driver: AsyncDriver,
) -> list[EntityNode]:
    """
    Find graph nodes relevant to the given EntityNodes.

    The names of the input nodes are used as text queries and their name
    embeddings (where present) as search vectors for hybrid_node_search. The
    search is restricted to the group ids of the input nodes.

    Parameters
    ----------
    nodes : list[EntityNode]
        The EntityNode objects to use as the basis for the search.
    driver : AsyncDriver
        The Neo4j driver instance for database operations.

    Returns
    -------
    list[EntityNode]
        Whatever hybrid_node_search returns for these inputs, using its default
        limit.
    """
    names = [node.name for node in nodes]
    # Nodes without an embedding contribute a text query only.
    name_embeddings = [
        node.name_embedding for node in nodes if node.name_embedding is not None
    ]
    node_group_ids = [node.group_id for node in nodes]

    return await hybrid_node_search(names, name_embeddings, driver, node_group_ids)
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_relevant_edges(
    driver: AsyncDriver,
    edges: list[EntityEdge],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """
    Find existing graph edges relevant to the given edges.

    For every input edge a fulltext search on its fact is run, plus a vector
    similarity search when the edge has a fact embedding. Each search is limited
    to the input edge's own group id. All searches run concurrently.

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance for database operations.
    edges : list[EntityEdge]
        The edges to use as the basis for the search.
    source_node_uuid : str | None
        Passed through to both edge search functions.
    target_node_uuid : str | None
        Passed through to both edge search functions.
    limit : int, optional
        Passed through to both edge search functions as the per-search limit.

    Returns
    -------
    list[EntityEdge]
        The found edges, de-duplicated by uuid (first occurrence wins), with all
        similarity-search results ahead of fulltext-search results.
    """
    start = time()

    similarity_searches = [
        edge_similarity_search(
            driver,
            edge.fact_embedding,
            source_node_uuid,
            target_node_uuid,
            [edge.group_id],
            limit,
        )
        for edge in edges
        if edge.fact_embedding is not None
    ]
    fulltext_searches = [
        edge_fulltext_search(
            driver, edge.fact, source_node_uuid, target_node_uuid, [edge.group_id], limit
        )
        for edge in edges
    ]
    results = await asyncio.gather(*similarity_searches, *fulltext_searches)

    # Keep the first occurrence of each uuid, preserving encounter order.
    relevant_edge_uuids = set()
    relevant_edges: list[EntityEdge] = []
    for result in results:
        for edge in result:
            if edge.uuid not in relevant_edge_uuids:
                relevant_edge_uuids.add(edge.uuid)
                relevant_edges.append(edge)

    end = time()
    logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms')

    return relevant_edges
|
2024-08-22 14:26:26 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# takes in a list of rankings of uuids
|
|
|
|
|
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
    """
    Fuse several uuid rankings into one with reciprocal rank fusion.

    Each uuid scores `1 / (index + rank_const)` for every ranking it appears in,
    where `index` is its zero-based position in that ranking; the scores are
    summed across rankings.

    Parameters
    ----------
    results : list[list[str]]
        The rankings to fuse, each ordered best-first.
    rank_const : optional
        Constant added to the zero-based position before taking the reciprocal.

    Returns
    -------
    list[str]
        Every uuid seen, ordered by descending fused score. Uuids with equal
        scores keep the order in which they were first encountered.
    """
    fused: dict[str, float] = defaultdict(float)
    for ranking in results:
        for position, uuid in enumerate(ranking):
            fused[uuid] += 1 / (position + rank_const)

    # sorted() is stable, also with reverse=True, so ties stay in first-seen order.
    return sorted(fused, key=lambda uuid: fused[uuid], reverse=True)
|
2024-08-26 18:34:57 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def node_distance_reranker(
    driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
) -> list[str]:
    """
    Reorder node uuids by their graph distance to a center node.

    For every uuid other than the center, the length of a shortest `RELATES_TO`
    path to the center node is looked up (one query per uuid, run concurrently).

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance for database operations.
    node_uuids : list[str]
        The uuids to rerank.
    center_node_uuid : str
        The uuid of the node distances are measured from.

    Returns
    -------
    list[str]
        The uuids ordered by ascending path length. Uuids with no path to the
        center sort last, and ties keep their input order. The center uuid is
        placed first, but only if it was present in `node_uuids`.
    """
    # The center node has no path to itself under the `+` quantifier, so it is
    # handled separately instead of being queried.
    filtered_uuids = [uuid for uuid in node_uuids if uuid != center_node_uuid]
    scores: dict[str, float] = {}

    # Find the shortest path to center node
    query = Query("""
        MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
        RETURN length(p) AS score
        """)

    path_results = await asyncio.gather(
        *[
            driver.execute_query(
                query,
                node_uuid=uuid,
                center_uuid=center_node_uuid,
            )
            for uuid in filtered_uuids
        ]
    )

    for uuid, result in zip(filtered_uuids, path_results):
        records = result[0]
        # No record means no path was found: treat the node as infinitely far away.
        distance: float = records[0]['score'] if records else float('inf')
        scores[uuid] = distance

    # rerank on shortest distance
    filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

    # Put the center back at the front only if the caller actually supplied it;
    # otherwise the result would contain a uuid that was never in the input.
    if center_node_uuid in node_uuids:
        filtered_uuids = [center_node_uuid] + filtered_uuids

    return filtered_uuids
|
2024-09-18 15:44:28 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
    """
    Rerank node uuids by how many episodes mention each node.

    The input rankings are first fused with rrf; the fused list is then sorted by
    the number of `MENTIONS` relationships pointing at each node (one count query
    per uuid, run concurrently).

    Parameters
    ----------
    driver : AsyncDriver
        The Neo4j driver instance for database operations.
    node_uuids : list[list[str]]
        Rankings of node uuids to fuse and rerank.

    Returns
    -------
    list[str]
        The fused uuids sorted by ascending mention count; equal counts keep their
        rrf order.
    """
    # use rrf as a preliminary ranker
    ranked_uuids = rrf(node_uuids)

    # Count the episodes that mention a given node
    mention_count_query = Query("""
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
        RETURN count(*) AS score
        """)

    count_results = await asyncio.gather(
        *[
            driver.execute_query(
                mention_count_query,
                node_uuid=uuid,
            )
            for uuid in ranked_uuids
        ]
    )

    # Each result's first record carries the count for the corresponding uuid.
    mention_counts: dict[str, float] = {
        uuid: result[0][0]['score'] for uuid, result in zip(ranked_uuids, count_results)
    }

    # NOTE(review): this sorts ascending, i.e. least-mentioned nodes come first -
    # confirm that is the intended direction.
    ranked_uuids.sort(key=lambda cur_uuid: mention_counts[cur_uuid])

    return ranked_uuids
|