From 794b705664bf7282e109dbfc07484ca8bec49399 Mon Sep 17 00:00:00 2001
From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com>
Date: Tue, 24 Sep 2024 15:55:30 -0400
Subject: [PATCH] Group id fix (#152)

* node distance and group_ids fixed

* get all with no group_id passed

* push

* push

* remove comments

* mypy

* mypy ids

* please mypy

* trust

* last one
---
 examples/podcast/podcast_runner.py             | 21 ++--
 graphiti_core/edges.py                         |  8 +-
 graphiti_core/graphiti.py                      | 12 +--
 graphiti_core/llm_client/openai_client.py      |  2 +-
 graphiti_core/nodes.py                         |  8 +-
 graphiti_core/search/search.py                 | 30 ++++--
 graphiti_core/search/search_utils.py           | 95 +++++++------------
 .../utils/maintenance/community_operations.py  | 12 +--
 .../utils/maintenance/edge_operations.py       |  2 +-
 .../maintenance/graph_data_operations.py       |  5 +-
 tests/test_graphiti_int.py                     |  8 +-
 11 files changed, 93 insertions(+), 110 deletions(-)

diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index 90a4a205..43ad29fc 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -63,28 +63,27 @@ async def main(use_bulk: bool = True):
     messages = parse_podcast_messages()
 
     if not use_bulk:
-        for i, message in enumerate(messages[3:4]):
+        for i, message in enumerate(messages[3:14]):
             await client.add_episode(
                 name=f'Message {i}',
                 episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
                 reference_time=message.actual_timestamp,
                 source_description='Podcast Transcript',
-                group_id='1',
             )
 
         # build communities
         await client.build_communities()
 
         # add additional messages to update communities
-        # for i, message in enumerate(messages[14:20]):
-        #     await client.add_episode(
-        #         name=f'Message {i}',
-        #         episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
-        #         reference_time=message.actual_timestamp,
-        #         source_description='Podcast Transcript',
-        #         group_id='1',
-        #         update_communities=True,
-        #     )
+        for i, message in enumerate(messages[14:20]):
+            await client.add_episode(
+                name=f'Message {i}',
+                episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
+                reference_time=message.actual_timestamp,
+                source_description='Podcast Transcript',
+                group_id='1',
+                update_communities=True,
+            )
 
         return
diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py
index 18f2f8a9..142c0381 100644
--- a/graphiti_core/edges.py
+++ b/graphiti_core/edges.py
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
 
 class Edge(BaseModel, ABC):
     uuid: str = Field(default_factory=lambda: str(uuid4()))
-    group_id: str | None = Field(description='partition of the graph')
+    group_id: str = Field(description='partition of the graph')
     source_node_uuid: str
     target_node_uuid: str
     created_at: datetime
@@ -131,7 +131,7 @@ class EpisodicEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
+    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
@@ -270,7 +270,7 @@ class EntityEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
+    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
@@ -360,7 +360,7 @@ class CommunityEdge(Edge):
         return edges
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
+    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 8a5eb23a..de1416b3 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -197,7 +197,7 @@ class Graphiti:
         self,
         reference_time: datetime,
         last_n: int = EPISODE_WINDOW_LEN,
-        group_ids: list[str | None] | None = None,
+        group_ids: list[str] | None = None,
     ) -> list[EpisodicNode]:
         """
         Retrieve the last n episodic nodes from the graph.
@@ -233,7 +233,7 @@ class Graphiti:
         source_description: str,
         reference_time: datetime,
         source: EpisodeType = EpisodeType.message,
-        group_id: str | None = None,
+        group_id: str = '',
         uuid: str | None = None,
         update_communities: bool = False,
     ):
@@ -446,7 +446,7 @@ class Graphiti:
         except Exception as e:
             raise e
 
-    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
+    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
         """
         Process multiple episodes in bulk and update the graph.
 
@@ -577,7 +577,7 @@ class Graphiti:
         self,
         query: str,
         center_node_uuid: str | None = None,
-        group_ids: list[str | None] | None = None,
+        group_ids: list[str] | None = None,
         num_results=DEFAULT_SEARCH_LIMIT,
     ) -> list[EntityEdge]:
         """
@@ -633,7 +633,7 @@ class Graphiti:
         self,
         query: str,
         config: SearchConfig,
-        group_ids: list[str | None] | None = None,
+        group_ids: list[str] | None = None,
         center_node_uuid: str | None = None,
     ) -> SearchResults:
         return await search(
@@ -644,7 +644,7 @@ class Graphiti:
         self,
         query: str,
         center_node_uuid: str | None = None,
-        group_ids: list[str | None] | None = None,
+        group_ids: list[str] | None = None,
         limit: int = DEFAULT_SEARCH_LIMIT,
     ) -> list[EntityNode]:
         """
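With this change, group_id on add_episode is a plain string that defaults to '' rather than an optional value. A minimal sketch of the calling pattern after this patch (connection details, names, and episode text are placeholders, not taken from the patch):

    import asyncio
    from datetime import datetime, timezone

    from graphiti_core.graphiti import Graphiti
    from graphiti_core.nodes import EpisodeType


    async def main() -> None:
        # placeholder connection details
        client = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

        # no group_id: the episode is stored under the default '' partition
        await client.add_episode(
            name='Message 0',
            episode_body='alice (speaker): hello world',
            source_description='Podcast Transcript',
            reference_time=datetime.now(timezone.utc),
            source=EpisodeType.message,
        )

        # explicit group_id: the episode is written to partition '1'
        await client.add_episode(
            name='Message 1',
            episode_body='bob (speaker): hi alice',
            source_description='Podcast Transcript',
            reference_time=datetime.now(timezone.utc),
            group_id='1',
        )


    asyncio.run(main())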
diff --git a/graphiti_core/llm_client/openai_client.py b/graphiti_core/llm_client/openai_client.py
index a1d9010e..f459a3f4 100644
--- a/graphiti_core/llm_client/openai_client.py
+++ b/graphiti_core/llm_client/openai_client.py
@@ -29,7 +29,7 @@ from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_MODEL = 'gpt-4o-2024-08-06'
+DEFAULT_MODEL = 'gpt-4o-mini'
 
 
 class OpenAIClient(LLMClient):
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index a0431a64..828a7ebb 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -70,7 +70,7 @@ class EpisodeType(Enum):
 class Node(BaseModel, ABC):
     uuid: str = Field(default_factory=lambda: str(uuid4()))
     name: str = Field(description='name of the node')
-    group_id: str | None = Field(description='partition of the graph')
+    group_id: str = Field(description='partition of the graph')
     labels: list[str] = Field(default_factory=list)
     created_at: datetime = Field(default_factory=lambda: datetime.now())
 
@@ -186,7 +186,7 @@ class EpisodicNode(Node):
         return episodes
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
+    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (e:Episodic) WHERE e.group_id IN $group_ids
@@ -281,7 +281,7 @@ class EntityNode(Node):
         return nodes
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
+    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity) WHERE n.group_id IN $group_ids
@@ -374,7 +374,7 @@ class CommunityNode(Node):
         return communities
 
     @classmethod
-    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
+    async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
         records, _, _ = await driver.execute_query(
             """
         MATCH (n:Community) WHERE n.group_id IN $group_ids
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index a8b1c9f9..862ececd 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -15,6 +15,7 @@ limitations under the License.
 """
 
 import logging
+from collections import defaultdict
 from time import time
 
 from neo4j import AsyncDriver
@@ -56,7 +57,7 @@ async def search(
     driver: AsyncDriver,
     embedder,
     query: str,
-    group_ids: list[str | None] | None,
+    group_ids: list[str] | None,
     config: SearchConfig,
     center_node_uuid: str | None = None,
 ) -> SearchResults:
@@ -103,7 +104,7 @@ async def edge_search(
     driver: AsyncDriver,
     embedder,
     query: str,
-    group_ids: list[str | None] | None,
+    group_ids: list[str] | None,
     config: EdgeSearchConfig,
     center_node_uuid: str | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
@@ -140,14 +141,21 @@ async def edge_search(
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
 
-        source_to_edge_uuid_map = {
-            edge.source_node_uuid: edge.uuid for result in search_results for edge in result
-        }
-        source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
+        # use rrf as a preliminary sort
+        sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
+        sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
+
+        # node distance reranking
+        source_to_edge_uuid_map = defaultdict(list)
+        for edge in sorted_results:
+            source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
+
+        source_uuids = [edge.source_node_uuid for edge in sorted_results]
 
         reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
 
-        reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids]
+        for node_uuid in reranked_node_uuids:
+            reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
 
     reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
 
@@ -161,7 +169,7 @@ async def node_search(
     driver: AsyncDriver,
     embedder,
     query: str,
-    group_ids: list[str | None] | None,
+    group_ids: list[str] | None,
     config: NodeSearchConfig,
     center_node_uuid: str | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
@@ -198,7 +206,9 @@ async def node_search(
     elif config.reranker == NodeReranker.node_distance:
         if center_node_uuid is None:
             raise SearchRerankerError('No center node provided for Node Distance reranker')
-        reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)
+        reranked_uuids = await node_distance_reranker(
+            driver, rrf(search_result_uuids), center_node_uuid
+        )
 
     reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
 
@@ -209,7 +219,7 @@ async def community_search(
     driver: AsyncDriver,
     embedder,
     query: str,
-    group_ids: list[str | None] | None,
+    group_ids: list[str] | None,
     config: CommunitySearchConfig,
     limit=DEFAULT_SEARCH_LIMIT,
 ) -> list[CommunityNode]:
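In edge_search, the Node Distance reranker no longer maps each source node to a single edge uuid: candidate lists are first fused with RRF, the edge uuids are then grouped per source node, and the groups are flattened in order of each source node's distance from the center node. A rough, self-contained sketch of that flow (not part of the patch; the rrf scoring shown and the distance map are stand-ins for the library's rrf helper and its shortest-path Cypher query, and the function names are made up for illustration):

    from collections import defaultdict


    def rrf(results: list[list[str]], rank_const: int = 1) -> list[str]:
        # reciprocal rank fusion over several ranked uuid lists (assumed scoring)
        scores: dict[str, float] = defaultdict(float)
        for ranking in results:
            for rank, uuid in enumerate(ranking):
                scores[uuid] += 1.0 / (rank + rank_const)
        return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)


    def rerank_edges_by_node_distance(
        search_results: list[list[tuple[str, str]]],  # lists of (edge_uuid, source_node_uuid)
        distance_from_center: dict[str, float],  # stand-in for the shortest-path query
    ) -> list[str]:
        edge_to_source = {edge: src for result in search_results for edge, src in result}

        # preliminary RRF sort, mirroring the change in edge_search
        sorted_edge_uuids = rrf([[edge for edge, _ in result] for result in search_results])

        # group edge uuids by their source node, preserving the RRF order
        source_to_edges: dict[str, list[str]] = defaultdict(list)
        for edge_uuid in sorted_edge_uuids:
            source_to_edges[edge_to_source[edge_uuid]].append(edge_uuid)

        # order source nodes by distance from the center node, then flatten
        reranked_sources = sorted(
            source_to_edges, key=lambda src: distance_from_center.get(src, float('inf'))
        )
        return [edge for src in reranked_sources for edge in source_to_edges[src]]


    # toy data: two ranked edge lists and a distance map around the center node
    print(
        rerank_edges_by_node_distance(
            [[('e1', 'n2'), ('e2', 'n1')], [('e2', 'n1'), ('e3', 'n3')]],
            {'n1': 1.0, 'n2': 2.0, 'n3': float('inf')},
        )
    )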
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 0cc19be5..110c3210 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -87,7 +87,7 @@ async def edge_fulltext_search(
     driver: AsyncDriver,
     query: str,
     source_node_uuid: str | None,
     target_node_uuid: str | None,
-    group_ids: list[str | None] | None = None,
+    group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN n.group_id IS NULL
-           ELSE n.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR n.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
@@ -120,10 +117,7 @@
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN r.group_id IS NULL
-           ELSE r.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
@@ -144,10 +138,7 @@
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN r.group_id IS NULL
-           ELSE r.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
@@ -168,10 +159,7 @@
        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
        YIELD relationship AS rel, score
        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN r.group_id IS NULL
-           ELSE r.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
@@ -209,7 +197,7 @@ async def edge_similarity_search(
     search_vector: list[float],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
-    group_ids: list[str | None] | None = None,
+    group_ids: list[str] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
@@ -217,10 +205,7 @@
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN r.group_id IS NULL
-           ELSE r.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
@@ -242,10 +227,7 @@
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN r.group_id IS NULL
-           ELSE r.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
@@ -266,10 +248,7 @@
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN r.group_id IS NULL
-           ELSE r.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
@@ -290,10 +269,7 @@
        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
        YIELD relationship AS rel, score
        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN r.group_id IS NULL
-           ELSE r.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR r.group_id IN $group_ids
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
@@ -327,7 +303,7 @@ async def node_fulltext_search(
     driver: AsyncDriver,
     query: str,
-    group_ids: list[str | None] | None = None,
+    group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
@@ -336,10 +312,7 @@
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
     records, _, _ = await driver.execute_query(
         """
        CALL db.index.fulltext.queryNodes("name_and_summary", $query)
        YIELD node AS n, score
-       WHERE CASE
-           WHEN $group_ids IS NULL THEN n.group_id IS NULL
-           ELSE n.group_id IN $group_ids
-       END
+       WHERE $group_ids IS NULL OR n.group_id IN $group_ids
        RETURN
            n.uuid AS uuid,
            n.group_id AS group_id,
@@ -362,17 +335,16 @@ async def node_similarity_search(
     driver: AsyncDriver,
     search_vector: list[float],
-    group_ids: list[str | None] | None = None,
+    group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
-    group_ids = group_ids if group_ids is not None else [None]
-
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
         """
        CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
        YIELD node AS n, score
-       MATCH (n WHERE n.group_id IN $group_ids)
+       MATCH (n:Entity)
+       WHERE $group_ids IS NULL OR n.group_id IN $group_ids
        RETURN
            n.uuid As uuid,
            n.group_id AS group_id,
@@ -394,18 +366,17 @@ async def community_fulltext_search(
     driver: AsyncDriver,
     query: str,
-    group_ids: list[str | None] | None = None,
+    group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[CommunityNode]:
-    group_ids = group_ids if group_ids is not None else [None]
-
     # BM25 search to get top communities
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
     records, _, _ = await driver.execute_query(
         """
        CALL db.index.fulltext.queryNodes("community_name", $query)
        YIELD node AS comm, score
-       MATCH (comm WHERE comm.group_id in $group_ids)
+       MATCH (comm:Community)
+       WHERE $group_ids IS NULL OR comm.group_id in $group_ids
        RETURN
            comm.uuid AS uuid,
            comm.group_id AS group_id,
@@ -428,17 +399,16 @@ async def community_similarity_search(
     driver: AsyncDriver,
     search_vector: list[float],
-    group_ids: list[str | None] | None = None,
+    group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[CommunityNode]:
-    group_ids = group_ids if group_ids is not None else [None]
-
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
         """
        CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
        YIELD node AS comm, score
-       MATCH (comm WHERE comm.group_id IN $group_ids)
+       MATCH (comm:Community)
+       WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
        RETURN
            comm.uuid As uuid,
            comm.group_id AS group_id,
@@ -461,7 +431,7 @@ async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
-    group_ids: list[str | None] | None = None,
+    group_ids: list[str] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
@@ -503,7 +473,6 @@ async def hybrid_node_search(
     """
     start = time()
-
     results: list[list[EntityNode]] = list(
         await asyncio.gather(
             *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
@@ -625,14 +594,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
 
 
 async def node_distance_reranker(
-    driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
+    driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
 ) -> list[str]:
-    # use rrf as a preliminary ranker
-    sorted_uuids = rrf(node_uuids)
+    # filter out the center node uuid
+    filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
     scores: dict[str, float] = {}
 
     # Find the shortest path to center node
-    query = Query("""
+    query = Query("""
        MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
        RETURN length(p) AS score
        """)
@@ -644,21 +613,23 @@ async def node_distance_reranker(
                 node_uuid=uuid,
                 center_uuid=center_node_uuid,
             )
-            for uuid in sorted_uuids
+            for uuid in filtered_uuids
         ]
     )
 
-    for uuid, result in zip(sorted_uuids, path_results):
+    for uuid, result in zip(filtered_uuids, path_results):
         records = result[0]
         record = records[0] if len(records) > 0 else None
         distance: float = record['score'] if record is not None else float('inf')
-        distance = 0 if uuid == center_node_uuid else distance
         scores[uuid] = distance
 
     # rerank on shortest distance
-    sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
+    filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
 
-    return sorted_uuids
+    # add back the filtered-out center uuid
+    filtered_uuids = [center_node_uuid] + filtered_uuids
+
+    return filtered_uuids
 
 
 async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
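The recurring Cypher change above replaces the CASE expression with a single predicate: when $group_ids is NULL the filter is a no-op and every partition is returned, otherwise results are restricted to the listed group ids. A rough standalone illustration of that predicate against a Neo4j instance (the connection details, credentials, and the helper name are placeholders, not part of the patch):

    import asyncio

    from neo4j import AsyncGraphDatabase


    async def fetch_entity_uuids(group_ids: list[str] | None) -> list[str]:
        # when group_ids is None the WHERE clause matches every partition;
        # otherwise only nodes whose group_id is in the list are returned
        query = """
        MATCH (n:Entity)
        WHERE $group_ids IS NULL OR n.group_id IN $group_ids
        RETURN n.uuid AS uuid
        """
        driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
        try:
            records, _, _ = await driver.execute_query(query, group_ids=group_ids)
            return [record['uuid'] for record in records]
        finally:
            await driver.close()


    # all partitions vs. a single partition
    print(asyncio.run(fetch_entity_uuids(None)))
    print(asyncio.run(fetch_entity_uuids(['1'])))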
diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py
index 7a384fec..fa3046a2 100644
--- a/graphiti_core/utils/maintenance/community_operations.py
+++ b/graphiti_core/utils/maintenance/community_operations.py
@@ -154,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
 
 
 async def build_community(
-        llm_client: LLMClient, community_cluster: list[EntityNode]
+    llm_client: LLMClient, community_cluster: list[EntityNode]
 ) -> tuple[CommunityNode, list[CommunityEdge]]:
     summaries = [entity.summary for entity in community_cluster]
     length = len(summaries)
@@ -168,7 +168,7 @@ async def build_community(
             *[
                 summarize_pair(llm_client, (str(left_summary), str(right_summary)))
                 for left_summary, right_summary in zip(
-                    summaries[: int(length / 2)], summaries[int(length / 2):]
+                    summaries[: int(length / 2)], summaries[int(length / 2) :]
                 )
             ]
         )
@@ -196,7 +196,7 @@
 
 
 async def build_communities(
-        driver: AsyncDriver, llm_client: LLMClient
+    driver: AsyncDriver, llm_client: LLMClient
 ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
     community_clusters = await get_community_clusters(driver)
 
@@ -227,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
 
 
 async def determine_entity_community(
-        driver: AsyncDriver, entity: EntityNode
+    driver: AsyncDriver, entity: EntityNode
 ) -> tuple[CommunityNode | None, bool]:
     # Check if the node is already part of a community
     records, _, _ = await driver.execute_query(
@@ -288,7 +288,7 @@
 
 
 async def update_community(
-        driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
+    driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
 ):
     community, is_new = await determine_entity_community(driver, entity)
 
@@ -307,4 +307,4 @@ async def update_community(
 
     await community.generate_name_embedding(embedder)
 
-    await community.save(driver)
\ No newline at end of file
+    await community.save(driver)
diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py
index d39594c7..e2953cb1 100644
--- a/graphiti_core/utils/maintenance/edge_operations.py
+++ b/graphiti_core/utils/maintenance/edge_operations.py
@@ -73,7 +73,7 @@ async def extract_edges(
     episode: EpisodicNode,
     nodes: list[EntityNode],
     previous_episodes: list[EpisodicNode],
-    group_id: str | None,
+    group_id: str = '',
 ) -> list[EntityEdge]:
     start = time()
 
diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py
index 446cb889..cdc39b31 100644
--- a/graphiti_core/utils/maintenance/graph_data_operations.py
+++ b/graphiti_core/utils/maintenance/graph_data_operations.py
@@ -101,7 +101,7 @@ async def retrieve_episodes(
     driver: AsyncDriver,
     reference_time: datetime,
     last_n: int = EPISODE_WINDOW_LEN,
-    group_ids: list[str | None] | None = None,
+    group_ids: list[str] | None = None,
 ) -> list[EpisodicNode]:
     """
     Retrieve the last n episodic nodes from the graph.
@@ -119,7 +119,8 @@
     """
     result = await driver.execute_query(
         """
-        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
+        MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
+        AND ($group_ids IS NULL OR e.group_id in $group_ids)
         RETURN e.content AS content,
             e.created_at AS created_at,
             e.valid_at AS valid_at,
diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py
index 488361d2..5ab04541 100644
--- a/tests/test_graphiti_int.py
+++ b/tests/test_graphiti_int.py
@@ -76,16 +76,18 @@ async def test_graphiti_init():
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
     await graphiti.build_communities()
 
-    edges = await graphiti.search('tania tetlow', group_ids=['1'])
+    edges = await graphiti.search(
+        'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
+    )
 
     logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
 
-    edges = await graphiti.search('issues with higher ed', group_ids=['1'])
+    edges = await graphiti.search('issues with higher ed', group_ids=None)
 
     logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
 
     results = await graphiti._search(
-        'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1']
+        'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None
     )
     pretty_results = {
         'edges': [edge.fact for edge in results.edges],
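As the test changes show, passing group_ids=None now searches across every partition instead of only ungrouped data, while a list of ids still restricts the search. A minimal sketch of the new calling convention (connection settings, the queries, and the center node uuid are placeholders, not part of the patch):

    import asyncio

    from graphiti_core.graphiti import Graphiti


    async def main() -> None:
        graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

        # group_ids=None: search edges across all graph partitions
        all_groups = await graphiti.search('issues with higher ed', group_ids=None)

        # group_ids=['1']: restrict the search to partition '1'
        one_group = await graphiti.search('issues with higher ed', group_ids=['1'])

        # the test above also passes a center_node_uuid, which the library can use
        # for distance-based reranking around that entity
        nearby = await graphiti.search(
            'tania tetlow', center_node_uuid='some-entity-uuid', group_ids=None
        )

        print(len(all_groups), len(one_group), len(nearby))


    asyncio.run(main())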