mirror of
https://github.com/getzep/graphiti.git
synced 2025-12-27 15:13:30 +00:00
Mentions reranker (#124)
* documentation update * update communities * mentions reranker * fix episode edge mentions * get episode mentions * add communities to mentions endpoint * rebase * defaults episodes to empty list * update
This commit is contained in:
parent
d133c39313
commit
e398f95612
@ -83,6 +83,7 @@ async def main(use_bulk: bool = True):
|
||||
reference_time=message.actual_timestamp,
|
||||
source_description='Podcast Transcript',
|
||||
group_id='1',
|
||||
update_communities=True,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
@ -109,13 +109,36 @@ class EpisodicEdge(Edge):
|
||||
raise EdgeNotFoundError(uuid)
|
||||
return edges[0]
|
||||
|
||||
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """
    Load every MENTIONS edge (Episodic -> Entity) whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the Cypher query.
    uuids : list[str]
        Uuids of the episodic edges to load.

    Returns
    -------
    list
        One edge per matching record, built with ``get_episodic_edge_from_record``.
        An empty list is returned when ``uuids`` is empty.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty but no edge matches any of them.
    """
    # Nothing requested: skip the query and, more importantly, never index
    # into an empty list when building the not-found error below.
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """,
        uuids=uuids,
    )

    edges = [get_episodic_edge_from_record(record) for record in records]

    logger.info(f'Found Edges: {uuids}')
    if not edges:
        # Report the first requested uuid; uuids is known to be non-empty here.
        raise EdgeNotFoundError(uuids[0])
    return edges
|
||||
|
||||
|
||||
class EntityEdge(Edge):
|
||||
name: str = Field(description='name of the edge, relation name')
|
||||
fact: str = Field(description='fact representing the edge and nodes that it connects')
|
||||
fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
|
||||
episodes: list[str] | None = Field(
|
||||
default=None,
|
||||
episodes: list[str] = Field(
|
||||
default=[],
|
||||
description='list of episode ids that reference these entity edges',
|
||||
)
|
||||
expired_at: datetime | None = Field(
|
||||
@ -197,6 +220,36 @@ class EntityEdge(Edge):
|
||||
raise EdgeNotFoundError(uuid)
|
||||
return edges[0]
|
||||
|
||||
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """
    Load every RELATES_TO edge (Entity -> Entity) whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the Cypher query.
    uuids : list[str]
        Uuids of the entity edges to load.

    Returns
    -------
    list
        One edge per matching record, built with ``get_entity_edge_from_record``.
        An empty list is returned when ``uuids`` is empty.

    Raises
    ------
    EdgeNotFoundError
        If ``uuids`` is non-empty but no edge matches any of them.
    """
    # Nothing requested: skip the query and never index into an empty list
    # when building the not-found error below.
    if not uuids:
        return []

    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid AS uuid,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at,
            e.name AS name,
            e.group_id AS group_id,
            e.fact AS fact,
            e.fact_embedding AS fact_embedding,
            e.episodes AS episodes,
            e.expired_at AS expired_at,
            e.valid_at AS valid_at,
            e.invalid_at AS invalid_at
        """,
        uuids=uuids,
    )

    edges = [get_entity_edge_from_record(record) for record in records]

    logger.info(f'Found Edges: {uuids}')
    if not edges:
        # Report the first requested uuid; uuids is known to be non-empty here.
        raise EdgeNotFoundError(uuids[0])
    return edges
|
||||
|
||||
|
||||
class CommunityEdge(Edge):
|
||||
async def save(self, driver: AsyncDriver):
|
||||
@ -239,6 +292,28 @@ class CommunityEdge(Edge):
|
||||
|
||||
return edges[0]
|
||||
|
||||
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """
    Load every HAS_MEMBER edge leaving a Community node whose uuid is in ``uuids``.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the Cypher query.
    uuids : list[str]
        Uuids of the community edges to load.

    Returns
    -------
    list
        One edge per matching record, built with ``get_community_edge_from_record``.
        The list is empty when nothing matches; no error is raised.
    """
    cypher = """
        MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
        WHERE e.uuid IN $uuids
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """
    rows, _, _ = await driver.execute_query(cypher, uuids=uuids)

    # Convert before logging so the log line is only emitted once every
    # record has been turned into an edge.
    community_edges = []
    for row in rows:
        community_edges.append(get_community_edge_from_record(row))

    logger.info(f'Found Edges: {uuids}')

    return community_edges
|
||||
|
||||
|
||||
# Edge helpers
|
||||
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
|
||||
|
||||
@ -35,6 +35,8 @@ from graphiti_core.search.search_config_recipes import (
|
||||
)
|
||||
from graphiti_core.search.search_utils import (
|
||||
RELEVANT_SCHEMA_LIMIT,
|
||||
get_communities_by_nodes,
|
||||
get_mentioned_nodes,
|
||||
get_relevant_edges,
|
||||
get_relevant_nodes,
|
||||
)
|
||||
@ -249,8 +251,6 @@ class Graphiti:
|
||||
An id for the graph partition the episode is a part of.
|
||||
uuid : str | None
|
||||
Optional uuid of the episode.
|
||||
update_communities: bool
|
||||
Optional. Determines if we should update communities
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -413,6 +413,8 @@ class Graphiti:
|
||||
|
||||
logger.info(f'Built episodic edges: {episodic_edges}')
|
||||
|
||||
episode.entity_edges = [edge.uuid for edge in entity_edges]
|
||||
|
||||
# Future optimization would be using batch operations to save nodes and edges
|
||||
await episode.save(self.driver)
|
||||
await asyncio.gather(*[node.save(self.driver) for node in nodes])
|
||||
@ -680,3 +682,19 @@ class Graphiti:
|
||||
await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
|
||||
).nodes
|
||||
return nodes
|
||||
|
||||
|
||||
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
    """
    Collect the edges, nodes and communities connected to the given episodes.

    Parameters
    ----------
    episode_uuids : list[str]
        Uuids of the episodes to inspect.

    Returns
    -------
    SearchResults
        ``edges`` are the entity edges listed in each episode's ``entity_edges``,
        ``nodes`` are the result of ``get_mentioned_nodes`` for the episodes, and
        ``communities`` are the result of ``get_communities_by_nodes`` for those nodes.
    """
    episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

    # Episodes that reference no entity edges have nothing to fetch; skipping
    # them keeps an empty uuid list from ever reaching EntityEdge.get_by_uuids.
    edges_list = await asyncio.gather(
        *[
            EntityEdge.get_by_uuids(self.driver, episode.entity_edges)
            for episode in episodes
            if episode.entity_edges
        ]
    )

    # Flatten the per-episode edge lists into a single list.
    edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]

    nodes = await get_mentioned_nodes(self.driver, episodes)

    communities = await get_communities_by_nodes(self.driver, nodes)

    return SearchResults(edges=edges, nodes=nodes, communities=communities)
|
||||
|
||||
@ -170,7 +170,8 @@ class EpisodicNode(Node):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (e:Episodic) WHERE e.uuid IN $uuids
|
||||
RETURN e.content AS content,
|
||||
RETURN DISTINCT
|
||||
e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
e.uuid AS uuid,
|
||||
|
||||
@ -42,6 +42,7 @@ from graphiti_core.search.search_utils import (
|
||||
community_similarity_search,
|
||||
edge_fulltext_search,
|
||||
edge_similarity_search,
|
||||
episode_mentions_reranker,
|
||||
node_distance_reranker,
|
||||
node_fulltext_search,
|
||||
node_similarity_search,
|
||||
@ -131,7 +132,7 @@ async def edge_search(
|
||||
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
||||
|
||||
reranked_uuids: list[str] = []
|
||||
if config.reranker == EdgeReranker.rrf:
|
||||
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
||||
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
||||
|
||||
reranked_uuids = rrf(search_result_uuids)
|
||||
@ -150,6 +151,9 @@ async def edge_search(
|
||||
|
||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
|
||||
if config.reranker == EdgeReranker.episode_mentions:
|
||||
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
||||
|
||||
return reranked_edges
|
||||
|
||||
|
||||
@ -189,6 +193,8 @@ async def node_search(
|
||||
reranked_uuids: list[str] = []
|
||||
if config.reranker == NodeReranker.rrf:
|
||||
reranked_uuids = rrf(search_result_uuids)
|
||||
elif config.reranker == NodeReranker.episode_mentions:
|
||||
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
|
||||
elif config.reranker == NodeReranker.node_distance:
|
||||
if center_node_uuid is None:
|
||||
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||
|
||||
@ -42,11 +42,13 @@ class CommunitySearchMethod(Enum):
|
||||
class EdgeReranker(Enum):
    """Reranking strategies that an edge search configuration can select."""

    # reciprocal rank fusion over the individual search result lists
    rrf = 'reciprocal_rank_fusion'
    # rerank relative to a center node
    node_distance = 'node_distance'
    # rerank by episode mentions
    episode_mentions = 'episode_mentions'
|
||||
|
||||
|
||||
class NodeReranker(Enum):
    """Reranking strategies that a node search configuration can select."""

    # reciprocal rank fusion over the individual search result lists
    rrf = 'reciprocal_rank_fusion'
    # rerank relative to a center node
    node_distance = 'node_distance'
    # rerank by episode mentions
    episode_mentions = 'episode_mentions'
|
||||
|
||||
|
||||
class CommunityReranker(Enum):
|
||||
|
||||
@ -59,6 +59,14 @@ EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||
)
|
||||
)
|
||||
|
||||
# Hybrid edge search (bm25 + cosine similarity) whose results are reranked
# with the episode-mentions reranker.
EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[
            EdgeSearchMethod.bm25,
            EdgeSearchMethod.cosine_similarity,
        ],
        reranker=EdgeReranker.episode_mentions,
    ),
)
|
||||
|
||||
# performs a hybrid search over nodes with rrf reranking
|
||||
NODE_HYBRID_SEARCH_RRF = SearchConfig(
|
||||
node_config=NodeSearchConfig(
|
||||
@ -75,6 +83,14 @@ NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
|
||||
)
|
||||
)
|
||||
|
||||
# Hybrid node search (bm25 + cosine similarity) whose results are reranked
# with the episode-mentions reranker.
NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
    node_config=NodeSearchConfig(
        search_methods=[
            NodeSearchMethod.bm25,
            NodeSearchMethod.cosine_similarity,
        ],
        reranker=NodeReranker.episode_mentions,
    ),
)
|
||||
|
||||
# performs a hybrid search over communities with rrf reranking
|
||||
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
|
||||
community_config=CommunitySearchConfig(
|
||||
|
||||
@ -36,7 +36,9 @@ logger = logging.getLogger(__name__)
|
||||
RELEVANT_SCHEMA_LIMIT = 3
|
||||
|
||||
|
||||
async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
|
||||
async def get_mentioned_nodes(
|
||||
driver: AsyncDriver, episodes: list[EpisodicNode]
|
||||
) -> list[EntityNode]:
|
||||
episode_uuids = [episode.uuid for episode in episodes]
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
@ -57,6 +59,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode])
|
||||
return nodes
|
||||
|
||||
|
||||
async def get_communities_by_nodes(
    driver: AsyncDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    """
    Return the distinct communities that have any of ``nodes`` as a member.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the Cypher query.
    nodes : list[EntityNode]
        Entity nodes whose communities should be looked up; only their
        ``uuid`` is used.

    Returns
    -------
    list[CommunityNode]
        One community per distinct record, built with
        ``get_community_node_from_record``.
    """
    node_uuids = [node.uuid for node in nodes]

    # Every projected column must be comma separated; a missing comma after
    # name_embedding makes the whole RETURN clause a Cypher syntax error.
    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
        RETURN DISTINCT
            c.uuid As uuid,
            c.group_id AS group_id,
            c.name AS name,
            c.name_embedding AS name_embedding,
            c.created_at AS created_at,
            c.summary AS summary
        """,
        uuids=node_uuids,
    )

    communities = [get_community_node_from_record(record) for record in records]

    return communities
|
||||
|
||||
|
||||
async def edge_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
@ -634,3 +659,34 @@ async def node_distance_reranker(
|
||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
||||
return sorted_uuids
|
||||
|
||||
|
||||
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
    """
    Rerank node uuids by the number of episodes that mention each node.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to execute the per-node count queries.
    node_uuids : list[list[str]]
        One list of node uuids per search method, as passed to ``rrf``.

    Returns
    -------
    list[str]
        The uuids produced by ``rrf``, reordered so that the most-mentioned
        nodes come first. Nodes with equal counts keep their rrf order.
    """
    # use rrf as a preliminary ranker
    sorted_uuids = rrf(node_uuids)

    # Count the MENTIONS relationships that point from Episodic nodes to the node
    query = Query("""
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid})
        RETURN count(*) AS score
        """)

    result_scores = await asyncio.gather(
        *[driver.execute_query(query, node_uuid=uuid) for uuid in sorted_uuids]
    )

    # result[0] is the record list; the aggregate query yields a single record
    scores: dict[str, float] = {
        uuid: result[0][0]['score'] for uuid, result in zip(sorted_uuids, result_scores)
    }

    # Most-mentioned first, matching how edge_search orders edges for the
    # episode_mentions reranker. list.sort is stable, so ties keep rrf order.
    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])

    return sorted_uuids
|
||||
|
||||
@ -163,6 +163,8 @@ async def dedupe_extracted_edges(
|
||||
if edge.uuid in duplicate_uuid_map:
|
||||
existing_uuid = duplicate_uuid_map[edge.uuid]
|
||||
existing_edge = edge_map[existing_uuid]
|
||||
# Add current episode to the episodes list
|
||||
existing_edge.episodes += edge.episodes
|
||||
edges.append(existing_edge)
|
||||
else:
|
||||
edges.append(edge)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user