Mentions reranker (#124)

* documentation update

* update communities

* mentions reranker

* fix episode edge mentions

* get episode mentions

* add communities to mentions endpoint

* rebase

* defaults episodes to empty list

* update
This commit is contained in:
Preston Rasmussen 2024-09-18 15:44:28 -04:00 committed by GitHub
parent d133c39313
commit e398f95612
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 184 additions and 7 deletions

View File

@ -83,6 +83,7 @@ async def main(use_bulk: bool = True):
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
update_communities=True,
)
return

View File

@ -109,13 +109,36 @@ class EpisodicEdge(Edge):
raise EdgeNotFoundError(uuid)
return edges[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.uuid IN $uuids
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
uuids=uuids,
)
edges = [get_episodic_edge_from_record(record) for record in records]
logger.info(f'Found Edges: {uuids}')
if len(edges) == 0:
raise EdgeNotFoundError(uuids[0])
return edges
class EntityEdge(Edge):
name: str = Field(description='name of the edge, relation name')
fact: str = Field(description='fact representing the edge and nodes that it connects')
fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
episodes: list[str] | None = Field(
default=None,
episodes: list[str] = Field(
default=[],
description='list of episode ids that reference these entity edges',
)
expired_at: datetime | None = Field(
@ -197,6 +220,36 @@ class EntityEdge(Edge):
raise EdgeNotFoundError(uuid)
return edges[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.uuid IN $uuids
RETURN
e.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at,
e.name AS name,
e.group_id AS group_id,
e.fact AS fact,
e.fact_embedding AS fact_embedding,
e.episodes AS episodes,
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
""",
uuids=uuids,
)
edges = [get_entity_edge_from_record(record) for record in records]
logger.info(f'Found Edges: {uuids}')
if len(edges) == 0:
raise EdgeNotFoundError(uuids[0])
return edges
class CommunityEdge(Edge):
async def save(self, driver: AsyncDriver):
@ -239,6 +292,28 @@ class CommunityEdge(Edge):
return edges[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
WHERE e.uuid IN $uuids
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
""",
uuids=uuids,
)
edges = [get_community_edge_from_record(record) for record in records]
logger.info(f'Found Edges: {uuids}')
return edges
# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:

View File

@ -35,6 +35,8 @@ from graphiti_core.search.search_config_recipes import (
)
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_communities_by_nodes,
get_mentioned_nodes,
get_relevant_edges,
get_relevant_nodes,
)
@ -249,8 +251,6 @@ class Graphiti:
An id for the graph partition the episode is a part of.
uuid : str | None
Optional uuid of the episode.
update_communities: bool
Optional. Determines if we should update communities
Returns
-------
@ -413,6 +413,8 @@ class Graphiti:
logger.info(f'Built episodic edges: {episodic_edges}')
episode.entity_edges = [edge.uuid for edge in entity_edges]
# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
@ -680,3 +682,19 @@ class Graphiti:
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
).nodes
return nodes
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
edges_list = await asyncio.gather(
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
)
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
nodes = await get_mentioned_nodes(self.driver, episodes)
communities = await get_communities_by_nodes(self.driver, nodes)
return SearchResults(edges=edges, nodes=nodes, communities=communities)

View File

@ -170,7 +170,8 @@ class EpisodicNode(Node):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.uuid IN $uuids
RETURN e.content AS content,
RETURN DISTINCT
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,

View File

@ -42,6 +42,7 @@ from graphiti_core.search.search_utils import (
community_similarity_search,
edge_fulltext_search,
edge_similarity_search,
episode_mentions_reranker,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
@ -131,7 +132,7 @@ async def edge_search(
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
reranked_uuids: list[str] = []
if config.reranker == EdgeReranker.rrf:
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids = rrf(search_result_uuids)
@ -150,6 +151,9 @@ async def edge_search(
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
return reranked_edges
@ -189,6 +193,8 @@ async def node_search(
reranked_uuids: list[str] = []
if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')

View File

@ -42,11 +42,13 @@ class CommunitySearchMethod(Enum):
class EdgeReranker(Enum):
rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance'
episode_mentions = 'episode_mentions'
class NodeReranker(Enum):
rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance'
episode_mentions = 'episode_mentions'
class CommunityReranker(Enum):

View File

@ -59,6 +59,14 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
)
)
# performs a hybrid search over edges with episode mention reranking
EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.episode_mentions,
)
)
# performs a hybrid search over nodes with rrf reranking
NODE_HYBRID_SEARCH_RRF = SearchConfig(
node_config=NodeSearchConfig(
@ -75,6 +83,14 @@ NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
)
)
# performs a hybrid search over nodes with episode mentions reranking
NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.episode_mentions,
)
)
# performs a hybrid search over communities with rrf reranking
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
community_config=CommunitySearchConfig(

View File

@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 3
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
async def get_mentioned_nodes(
driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
"""
@ -57,6 +59,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
return nodes
async def get_communities_by_nodes(
driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
node_uuids = [node.uuid for node in nodes]
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
RETURN DISTINCT
c.uuid As uuid,
c.group_id AS group_id,
c.name AS name,
c.name_embedding AS name_embedding
c.created_at AS created_at,
c.summary AS summary
""",
uuids=node_uuids,
)
communities = [get_community_node_from_record(record) for record in records]
return communities
async def edge_fulltext_search(
driver: AsyncDriver,
query: str,
@ -634,3 +659,34 @@ async def node_distance_reranker(
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return sorted_uuids
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
# use rrf as a preliminary ranker
sorted_uuids = rrf(node_uuids)
scores: dict[str, float] = {}
# Find the shortest path to center node
query = Query("""
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
RETURN count(*) AS score
""")
result_scores = await asyncio.gather(
*[
driver.execute_query(
query,
node_uuid=uuid,
)
for uuid in sorted_uuids
]
)
for uuid, result in zip(sorted_uuids, result_scores):
record = result[0][0]
scores[uuid] = record['score']
# rerank on shortest distance
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return sorted_uuids

View File

@ -163,6 +163,8 @@ async def dedupe_extracted_edges(
if edge.uuid in duplicate_uuid_map:
existing_uuid = duplicate_uuid_map[edge.uuid]
existing_edge = edge_map[existing_uuid]
# Add current episode to the episodes list
existing_edge.episodes += edge.episodes
edges.append(existing_edge)
else:
edges.append(edge)