mirror of
https://github.com/getzep/graphiti.git
synced 2025-12-28 15:45:09 +00:00
Group id fix (#152)
* node distance and group_ids fixed * get all with no group_id passed * push * push * remove comments * mypy * mypy ids * please mypy * trust * last one
This commit is contained in:
parent
cfeb58daba
commit
794b705664
@ -63,28 +63,27 @@ async def main(use_bulk: bool = True):
|
||||
messages = parse_podcast_messages()
|
||||
|
||||
if not use_bulk:
|
||||
for i, message in enumerate(messages[3:4]):
|
||||
for i, message in enumerate(messages[3:14]):
|
||||
await client.add_episode(
|
||||
name=f'Message {i}',
|
||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
reference_time=message.actual_timestamp,
|
||||
source_description='Podcast Transcript',
|
||||
group_id='1',
|
||||
)
|
||||
|
||||
# build communities
|
||||
await client.build_communities()
|
||||
|
||||
# add additional messages to update communities
|
||||
# for i, message in enumerate(messages[14:20]):
|
||||
# await client.add_episode(
|
||||
# name=f'Message {i}',
|
||||
# episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
# reference_time=message.actual_timestamp,
|
||||
# source_description='Podcast Transcript',
|
||||
# group_id='1',
|
||||
# update_communities=True,
|
||||
# )
|
||||
for i, message in enumerate(messages[14:20]):
|
||||
await client.add_episode(
|
||||
name=f'Message {i}',
|
||||
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
reference_time=message.actual_timestamp,
|
||||
source_description='Podcast Transcript',
|
||||
group_id='1',
|
||||
update_communities=True,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class Edge(BaseModel, ABC):
|
||||
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
||||
group_id: str | None = Field(description='partition of the graph')
|
||||
group_id: str = Field(description='partition of the graph')
|
||||
source_node_uuid: str
|
||||
target_node_uuid: str
|
||||
created_at: datetime
|
||||
@ -131,7 +131,7 @@ class EpisodicEdge(Edge):
|
||||
return edges
|
||||
|
||||
@classmethod
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
|
||||
@ -270,7 +270,7 @@ class EntityEdge(Edge):
|
||||
return edges
|
||||
|
||||
@classmethod
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
@ -360,7 +360,7 @@ class CommunityEdge(Edge):
|
||||
return edges
|
||||
|
||||
@classmethod
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
|
||||
|
||||
@ -197,7 +197,7 @@ class Graphiti:
|
||||
self,
|
||||
reference_time: datetime,
|
||||
last_n: int = EPISODE_WINDOW_LEN,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
) -> list[EpisodicNode]:
|
||||
"""
|
||||
Retrieve the last n episodic nodes from the graph.
|
||||
@ -233,7 +233,7 @@ class Graphiti:
|
||||
source_description: str,
|
||||
reference_time: datetime,
|
||||
source: EpisodeType = EpisodeType.message,
|
||||
group_id: str | None = None,
|
||||
group_id: str = '',
|
||||
uuid: str | None = None,
|
||||
update_communities: bool = False,
|
||||
):
|
||||
@ -446,7 +446,7 @@ class Graphiti:
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
|
||||
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
|
||||
"""
|
||||
Process multiple episodes in bulk and update the graph.
|
||||
|
||||
@ -577,7 +577,7 @@ class Graphiti:
|
||||
self,
|
||||
query: str,
|
||||
center_node_uuid: str | None = None,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
num_results=DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
"""
|
||||
@ -633,7 +633,7 @@ class Graphiti:
|
||||
self,
|
||||
query: str,
|
||||
config: SearchConfig,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
return await search(
|
||||
@ -644,7 +644,7 @@ class Graphiti:
|
||||
self,
|
||||
query: str,
|
||||
center_node_uuid: str | None = None,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
limit: int = DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
"""
|
||||
|
||||
@ -29,7 +29,7 @@ from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MODEL = 'gpt-4o-2024-08-06'
|
||||
DEFAULT_MODEL = 'gpt-4o-mini'
|
||||
|
||||
|
||||
class OpenAIClient(LLMClient):
|
||||
|
||||
@ -70,7 +70,7 @@ class EpisodeType(Enum):
|
||||
class Node(BaseModel, ABC):
|
||||
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str = Field(description='name of the node')
|
||||
group_id: str | None = Field(description='partition of the graph')
|
||||
group_id: str = Field(description='partition of the graph')
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now())
|
||||
|
||||
@ -186,7 +186,7 @@ class EpisodicNode(Node):
|
||||
return episodes
|
||||
|
||||
@classmethod
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
|
||||
@ -281,7 +281,7 @@ class EntityNode(Node):
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity) WHERE n.group_id IN $group_ids
|
||||
@ -374,7 +374,7 @@ class CommunityNode(Node):
|
||||
return communities
|
||||
|
||||
@classmethod
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
|
||||
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Community) WHERE n.group_id IN $group_ids
|
||||
|
||||
@ -15,6 +15,7 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from time import time
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
@ -56,7 +57,7 @@ async def search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
group_ids: list[str | None] | None,
|
||||
group_ids: list[str] | None,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
@ -103,7 +104,7 @@ async def edge_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
group_ids: list[str | None] | None,
|
||||
group_ids: list[str] | None,
|
||||
config: EdgeSearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
@ -140,14 +141,21 @@ async def edge_search(
|
||||
if center_node_uuid is None:
|
||||
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||
|
||||
source_to_edge_uuid_map = {
|
||||
edge.source_node_uuid: edge.uuid for result in search_results for edge in result
|
||||
}
|
||||
source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
|
||||
# use rrf as a preliminary sort
|
||||
sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
|
||||
sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
|
||||
|
||||
# node distance reranking
|
||||
source_to_edge_uuid_map = defaultdict(list)
|
||||
for edge in sorted_results:
|
||||
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
|
||||
|
||||
source_uuids = [edge.source_node_uuid for edge in sorted_results]
|
||||
|
||||
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
|
||||
|
||||
reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids]
|
||||
for node_uuid in reranked_node_uuids:
|
||||
reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
|
||||
|
||||
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
|
||||
@ -161,7 +169,7 @@ async def node_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
group_ids: list[str | None] | None,
|
||||
group_ids: list[str] | None,
|
||||
config: NodeSearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
@ -198,7 +206,9 @@ async def node_search(
|
||||
elif config.reranker == NodeReranker.node_distance:
|
||||
if center_node_uuid is None:
|
||||
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
||||
reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)
|
||||
reranked_uuids = await node_distance_reranker(
|
||||
driver, rrf(search_result_uuids), center_node_uuid
|
||||
)
|
||||
|
||||
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
||||
|
||||
@ -209,7 +219,7 @@ async def community_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
group_ids: list[str | None] | None,
|
||||
group_ids: list[str] | None,
|
||||
config: CommunitySearchConfig,
|
||||
limit=DEFAULT_SEARCH_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
|
||||
@ -87,7 +87,7 @@ async def edge_fulltext_search(
|
||||
query: str,
|
||||
source_node_uuid: str | None,
|
||||
target_node_uuid: str | None,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
# fulltext search over facts
|
||||
@ -95,10 +95,7 @@ async def edge_fulltext_search(
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN n.group_id IS NULL
|
||||
ELSE n.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -120,10 +117,7 @@ async def edge_fulltext_search(
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
||||
ELSE r.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -144,10 +138,7 @@ async def edge_fulltext_search(
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
||||
ELSE r.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -168,10 +159,7 @@ async def edge_fulltext_search(
|
||||
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
||||
ELSE r.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -209,7 +197,7 @@ async def edge_similarity_search(
|
||||
search_vector: list[float],
|
||||
source_node_uuid: str | None,
|
||||
target_node_uuid: str | None,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityEdge]:
|
||||
# vector similarity search over embedded facts
|
||||
@ -217,10 +205,7 @@ async def edge_similarity_search(
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
||||
ELSE r.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -242,10 +227,7 @@ async def edge_similarity_search(
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
||||
ELSE r.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -266,10 +248,7 @@ async def edge_similarity_search(
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
||||
ELSE r.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -290,10 +269,7 @@ async def edge_similarity_search(
|
||||
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
|
||||
YIELD relationship AS rel, score
|
||||
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN r.group_id IS NULL
|
||||
ELSE r.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
|
||||
RETURN
|
||||
r.uuid AS uuid,
|
||||
r.group_id AS group_id,
|
||||
@ -327,7 +303,7 @@ async def edge_similarity_search(
|
||||
async def node_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
# BM25 search to get top nodes
|
||||
@ -336,10 +312,7 @@ async def node_fulltext_search(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
|
||||
YIELD node AS n, score
|
||||
WHERE CASE
|
||||
WHEN $group_ids IS NULL THEN n.group_id IS NULL
|
||||
ELSE n.group_id IN $group_ids
|
||||
END
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
RETURN
|
||||
n.uuid AS uuid,
|
||||
n.group_id AS group_id,
|
||||
@ -362,17 +335,16 @@ async def node_fulltext_search(
|
||||
async def node_similarity_search(
|
||||
driver: AsyncDriver,
|
||||
search_vector: list[float],
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
|
||||
YIELD node AS n, score
|
||||
MATCH (n WHERE n.group_id IN $group_ids)
|
||||
MATCH (n:Entity)
|
||||
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
|
||||
RETURN
|
||||
n.uuid As uuid,
|
||||
n.group_id AS group_id,
|
||||
@ -394,18 +366,17 @@ async def node_similarity_search(
|
||||
async def community_fulltext_search(
|
||||
driver: AsyncDriver,
|
||||
query: str,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# BM25 search to get top communities
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("community_name", $query)
|
||||
YIELD node AS comm, score
|
||||
MATCH (comm WHERE comm.group_id in $group_ids)
|
||||
MATCH (comm:Community)
|
||||
WHERE $group_ids IS NULL OR comm.group_id in $group_ids
|
||||
RETURN
|
||||
comm.uuid AS uuid,
|
||||
comm.group_id AS group_id,
|
||||
@ -428,17 +399,16 @@ async def community_fulltext_search(
|
||||
async def community_similarity_search(
|
||||
driver: AsyncDriver,
|
||||
search_vector: list[float],
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
limit=RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[CommunityNode]:
|
||||
group_ids = group_ids if group_ids is not None else [None]
|
||||
|
||||
# vector similarity search over entity names
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
|
||||
YIELD node AS comm, score
|
||||
MATCH (comm WHERE comm.group_id IN $group_ids)
|
||||
MATCH (comm:Community)
|
||||
WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
|
||||
RETURN
|
||||
comm.uuid As uuid,
|
||||
comm.group_id AS group_id,
|
||||
@ -461,7 +431,7 @@ async def hybrid_node_search(
|
||||
queries: list[str],
|
||||
embeddings: list[list[float]],
|
||||
driver: AsyncDriver,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
limit: int = RELEVANT_SCHEMA_LIMIT,
|
||||
) -> list[EntityNode]:
|
||||
"""
|
||||
@ -503,7 +473,6 @@ async def hybrid_node_search(
|
||||
"""
|
||||
|
||||
start = time()
|
||||
|
||||
results: list[list[EntityNode]] = list(
|
||||
await asyncio.gather(
|
||||
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
|
||||
@ -625,14 +594,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
||||
|
||||
|
||||
async def node_distance_reranker(
|
||||
driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
|
||||
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
||||
) -> list[str]:
|
||||
# use rrf as a preliminary ranker
|
||||
sorted_uuids = rrf(node_uuids)
|
||||
# filter out node_uuid center node node uuid
|
||||
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
|
||||
scores: dict[str, float] = {}
|
||||
|
||||
# Find the shortest path to center node
|
||||
query = Query("""
|
||||
query = Query("""
|
||||
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
|
||||
RETURN length(p) AS score
|
||||
""")
|
||||
@ -644,21 +613,23 @@ async def node_distance_reranker(
|
||||
node_uuid=uuid,
|
||||
center_uuid=center_node_uuid,
|
||||
)
|
||||
for uuid in sorted_uuids
|
||||
for uuid in filtered_uuids
|
||||
]
|
||||
)
|
||||
|
||||
for uuid, result in zip(sorted_uuids, path_results):
|
||||
for uuid, result in zip(filtered_uuids, path_results):
|
||||
records = result[0]
|
||||
record = records[0] if len(records) > 0 else None
|
||||
distance: float = record['score'] if record is not None else float('inf')
|
||||
distance = 0 if uuid == center_node_uuid else distance
|
||||
scores[uuid] = distance
|
||||
|
||||
# rerank on shortest distance
|
||||
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
||||
|
||||
return sorted_uuids
|
||||
# add back in filtered center uuids
|
||||
filtered_uuids = [center_node_uuid] + filtered_uuids
|
||||
|
||||
return filtered_uuids
|
||||
|
||||
|
||||
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
|
||||
|
||||
@ -154,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
|
||||
|
||||
|
||||
async def build_community(
|
||||
llm_client: LLMClient, community_cluster: list[EntityNode]
|
||||
llm_client: LLMClient, community_cluster: list[EntityNode]
|
||||
) -> tuple[CommunityNode, list[CommunityEdge]]:
|
||||
summaries = [entity.summary for entity in community_cluster]
|
||||
length = len(summaries)
|
||||
@ -168,7 +168,7 @@ async def build_community(
|
||||
*[
|
||||
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
|
||||
for left_summary, right_summary in zip(
|
||||
summaries[: int(length / 2)], summaries[int(length / 2):]
|
||||
summaries[: int(length / 2)], summaries[int(length / 2) :]
|
||||
)
|
||||
]
|
||||
)
|
||||
@ -196,7 +196,7 @@ async def build_community(
|
||||
|
||||
|
||||
async def build_communities(
|
||||
driver: AsyncDriver, llm_client: LLMClient
|
||||
driver: AsyncDriver, llm_client: LLMClient
|
||||
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
|
||||
community_clusters = await get_community_clusters(driver)
|
||||
|
||||
@ -227,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
|
||||
|
||||
|
||||
async def determine_entity_community(
|
||||
driver: AsyncDriver, entity: EntityNode
|
||||
driver: AsyncDriver, entity: EntityNode
|
||||
) -> tuple[CommunityNode | None, bool]:
|
||||
# Check if the node is already part of a community
|
||||
records, _, _ = await driver.execute_query(
|
||||
@ -288,7 +288,7 @@ async def determine_entity_community(
|
||||
|
||||
|
||||
async def update_community(
|
||||
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
|
||||
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
|
||||
):
|
||||
community, is_new = await determine_entity_community(driver, entity)
|
||||
|
||||
@ -307,4 +307,4 @@ async def update_community(
|
||||
|
||||
await community.generate_name_embedding(embedder)
|
||||
|
||||
await community.save(driver)
|
||||
await community.save(driver)
|
||||
|
||||
@ -73,7 +73,7 @@ async def extract_edges(
|
||||
episode: EpisodicNode,
|
||||
nodes: list[EntityNode],
|
||||
previous_episodes: list[EpisodicNode],
|
||||
group_id: str | None,
|
||||
group_id: str = '',
|
||||
) -> list[EntityEdge]:
|
||||
start = time()
|
||||
|
||||
|
||||
@ -101,7 +101,7 @@ async def retrieve_episodes(
|
||||
driver: AsyncDriver,
|
||||
reference_time: datetime,
|
||||
last_n: int = EPISODE_WINDOW_LEN,
|
||||
group_ids: list[str | None] | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
) -> list[EpisodicNode]:
|
||||
"""
|
||||
Retrieve the last n episodic nodes from the graph.
|
||||
@ -119,7 +119,8 @@ async def retrieve_episodes(
|
||||
"""
|
||||
result = await driver.execute_query(
|
||||
"""
|
||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
|
||||
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
|
||||
AND ($group_ids IS NULL) OR e.group_id in $group_ids
|
||||
RETURN e.content AS content,
|
||||
e.created_at AS created_at,
|
||||
e.valid_at AS valid_at,
|
||||
|
||||
@ -76,16 +76,18 @@ async def test_graphiti_init():
|
||||
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
|
||||
await graphiti.build_communities()
|
||||
|
||||
edges = await graphiti.search('tania tetlow', group_ids=['1'])
|
||||
edges = await graphiti.search(
|
||||
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
|
||||
)
|
||||
|
||||
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
|
||||
|
||||
edges = await graphiti.search('issues with higher ed', group_ids=['1'])
|
||||
edges = await graphiti.search('issues with higher ed', group_ids=None)
|
||||
|
||||
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
|
||||
|
||||
results = await graphiti._search(
|
||||
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1']
|
||||
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None
|
||||
)
|
||||
pretty_results = {
|
||||
'edges': [edge.fact for edge in results.edges],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user