Group id fix (#152)

* node distance and group_ids fixed

* get all with no group_id passed

* push

* push

* remove comments

* mypy

* mypy ids

* please mypy

* trust

* last one
Preston Rasmussen 2024-09-24 15:55:30 -04:00 committed by GitHub
parent cfeb58daba
commit 794b705664
11 changed files with 93 additions and 110 deletions


@ -63,28 +63,27 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()
if not use_bulk:
for i, message in enumerate(messages[3:4]):
for i, message in enumerate(messages[3:14]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
)
# build communities
await client.build_communities()
# add additional messages to update communities
# for i, message in enumerate(messages[14:20]):
# await client.add_episode(
# name=f'Message {i}',
# episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
# reference_time=message.actual_timestamp,
# source_description='Podcast Transcript',
# group_id='1',
# update_communities=True,
# )
for i, message in enumerate(messages[14:20]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
group_id='1',
update_communities=True,
)
return


@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
group_id: str | None = Field(description='partition of the graph')
group_id: str = Field(description='partition of the graph')
source_node_uuid: str
target_node_uuid: str
created_at: datetime
@ -131,7 +131,7 @@ class EpisodicEdge(Edge):
return edges
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
@ -270,7 +270,7 @@ class EntityEdge(Edge):
return edges
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
@ -360,7 +360,7 @@ class CommunityEdge(Edge):
return edges
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)

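With `group_id` now a required string on edges (and on nodes, below), `get_by_group_ids` no longer accepts `None` entries. A minimal calling sketch, assuming `driver` is a connected `neo4j.AsyncDriver` and the calls run inside an async function:

```python
# group_ids must now be plain strings; edges created without an explicit
# group presumably carry the new default group_id of ''.
episodic_edges = await EpisodicEdge.get_by_group_ids(driver, group_ids=['1'])
entity_edges = await EntityEdge.get_by_group_ids(driver, group_ids=['1', '2'])
```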

@ -197,7 +197,7 @@ class Graphiti:
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@ -233,7 +233,7 @@ class Graphiti:
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
group_id: str | None = None,
group_id: str = '',
uuid: str | None = None,
update_communities: bool = False,
):
@ -446,7 +446,7 @@ class Graphiti:
except Exception as e:
raise e
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str = ''):
"""
Process multiple episodes in bulk and update the graph.
@ -577,7 +577,7 @@ class Graphiti:
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
) -> list[EntityEdge]:
"""
@ -633,7 +633,7 @@ class Graphiti:
self,
query: str,
config: SearchConfig,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
center_node_uuid: str | None = None,
) -> SearchResults:
return await search(
@ -644,7 +644,7 @@ class Graphiti:
self,
query: str,
center_node_uuid: str | None = None,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
limit: int = DEFAULT_SEARCH_LIMIT,
) -> list[EntityNode]:
"""

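Net effect on the public API: `add_episode` now defaults `group_id` to the empty string, and every search entry point takes `group_ids: list[str] | None`, where `None` means no partition filter. A usage sketch (an initialized `Graphiti` client is assumed):

```python
from datetime import datetime, timezone

await client.add_episode(
    name='Message 0',
    episode_body='Speaker A (host): intro remarks',
    reference_time=datetime.now(timezone.utc),
    source_description='Podcast Transcript',
    group_id='1',  # defaults to '' when omitted
)

# group_ids=None searches across all partitions; ['1'] restricts to one group.
all_edges = await client.search('issues with higher ed', group_ids=None)
group_edges = await client.search('issues with higher ed', group_ids=['1'])
```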

@ -29,7 +29,7 @@ from .errors import RateLimitError
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gpt-4o-2024-08-06'
DEFAULT_MODEL = 'gpt-4o-mini'
class OpenAIClient(LLMClient):


@ -70,7 +70,7 @@ class EpisodeType(Enum):
class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(description='name of the node')
group_id: str | None = Field(description='partition of the graph')
group_id: str = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now())
@ -186,7 +186,7 @@ class EpisodicNode(Node):
return episodes
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
@ -281,7 +281,7 @@ class EntityNode(Node):
return nodes
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity) WHERE n.group_id IN $group_ids
@ -374,7 +374,7 @@ class CommunityNode(Node):
return communities
@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str | None]):
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community) WHERE n.group_id IN $group_ids


@ -15,6 +15,7 @@ limitations under the License.
"""
import logging
from collections import defaultdict
from time import time
from neo4j import AsyncDriver
@ -56,7 +57,7 @@ async def search(
driver: AsyncDriver,
embedder,
query: str,
group_ids: list[str | None] | None,
group_ids: list[str] | None,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
@ -103,7 +104,7 @@ async def edge_search(
driver: AsyncDriver,
embedder,
query: str,
group_ids: list[str | None] | None,
group_ids: list[str] | None,
config: EdgeSearchConfig,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
@ -140,14 +141,21 @@ async def edge_search(
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
source_to_edge_uuid_map = {
edge.source_node_uuid: edge.uuid for result in search_results for edge in result
}
source_uuids = [[edge.source_node_uuid for edge in result] for result in search_results]
# use rrf as a preliminary sort
sorted_result_uuids = rrf([[edge.uuid for edge in result] for result in search_results])
sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
# node distance reranking
source_to_edge_uuid_map = defaultdict(list)
for edge in sorted_results:
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
source_uuids = [edge.source_node_uuid for edge in sorted_results]
reranked_node_uuids = await node_distance_reranker(driver, source_uuids, center_node_uuid)
reranked_uuids = [source_to_edge_uuid_map[node_uuid] for node_uuid in reranked_node_uuids]
for node_uuid in reranked_node_uuids:
reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
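The reranking fix matters when several edges share a source node: the old dict comprehension kept only one edge per source uuid, while the defaultdict of lists keeps all of them. A toy illustration with made-up uuids:

```python
from collections import defaultdict

# (edge_uuid, source_node_uuid) pairs, already in rrf order
sorted_results = [('edge-1', 'node-a'), ('edge-2', 'node-a'), ('edge-3', 'node-b')]

source_to_edge_uuid_map: dict[str, list[str]] = defaultdict(list)
for edge_uuid, source_uuid in sorted_results:
    source_to_edge_uuid_map[source_uuid].append(edge_uuid)

# suppose node distance ranks node-b closer to the center node than node-a
reranked_node_uuids = ['node-b', 'node-a']

reranked_uuids: list[str] = []
for node_uuid in reranked_node_uuids:
    reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])

print(reranked_uuids)  # ['edge-3', 'edge-1', 'edge-2'] -- no edge is dropped
```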
@ -161,7 +169,7 @@ async def node_search(
driver: AsyncDriver,
embedder,
query: str,
group_ids: list[str | None] | None,
group_ids: list[str] | None,
config: NodeSearchConfig,
center_node_uuid: str | None = None,
limit=DEFAULT_SEARCH_LIMIT,
@ -198,7 +206,9 @@ async def node_search(
elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid)
reranked_uuids = await node_distance_reranker(
driver, rrf(search_result_uuids), center_node_uuid
)
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
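Both rerankers now run `rrf` as a preliminary sort. For reference, a generic reciprocal-rank-fusion sketch consistent with the signature shown further down (`rrf(results: list[list[str]], rank_const=1)`); the library's exact scoring may differ:

```python
from collections import defaultdict

def rrf_sketch(results: list[list[str]], rank_const: int = 1) -> list[str]:
    # each uuid scores 1 / (rank + rank_const) in every result list it appears in
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result):
            scores[uuid] += 1.0 / (rank + rank_const)
    return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)

print(rrf_sketch([['a', 'b', 'c'], ['c', 'a', 'b']]))  # ['a', 'c', 'b']
```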
@ -209,7 +219,7 @@ async def community_search(
driver: AsyncDriver,
embedder,
query: str,
group_ids: list[str | None] | None,
group_ids: list[str] | None,
config: CommunitySearchConfig,
limit=DEFAULT_SEARCH_LIMIT,
) -> list[CommunityNode]:


@ -87,7 +87,7 @@ async def edge_fulltext_search(
query: str,
source_node_uuid: str | None,
target_node_uuid: str | None,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# fulltext search over facts
@ -95,10 +95,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE CASE
WHEN $group_ids IS NULL THEN n.group_id IS NULL
ELSE n.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
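Throughout this file the CASE expression is replaced by the simpler predicate `$group_ids IS NULL OR n.group_id IN $group_ids`. This also changes semantics: previously `group_ids=None` matched only rows whose `group_id` was NULL; now it disables the filter entirely. A minimal Python analogue, for illustration only:

```python
def group_filter_matches(group_ids: list[str] | None, group_id: str) -> bool:
    # Python analogue of `$group_ids IS NULL OR n.group_id IN $group_ids`
    return group_ids is None or group_id in group_ids

assert group_filter_matches(None, '1')         # None now means "match everything"
assert group_filter_matches(['1', '2'], '1')
assert not group_filter_matches(['2'], '1')
```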
@ -120,10 +117,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE CASE
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -144,10 +138,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE CASE
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -168,10 +159,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE CASE
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -209,7 +197,7 @@ async def edge_similarity_search(
search_vector: list[float],
source_node_uuid: str | None,
target_node_uuid: str | None,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
@ -217,10 +205,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE CASE
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -242,10 +227,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
WHERE CASE
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -266,10 +248,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
WHERE CASE
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -290,10 +269,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
YIELD relationship AS rel, score
MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
WHERE CASE
WHEN $group_ids IS NULL THEN r.group_id IS NULL
ELSE r.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR r.group_id IN $group_ids
RETURN
r.uuid AS uuid,
r.group_id AS group_id,
@ -327,7 +303,7 @@ async def edge_similarity_search(
async def node_fulltext_search(
driver: AsyncDriver,
query: str,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
# BM25 search to get top nodes
@ -336,10 +312,7 @@ async def node_fulltext_search(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query)
YIELD node AS n, score
WHERE CASE
WHEN $group_ids IS NULL THEN n.group_id IS NULL
ELSE n.group_id IN $group_ids
END
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
@ -362,17 +335,16 @@ async def node_fulltext_search(
async def node_similarity_search(
driver: AsyncDriver,
search_vector: list[float],
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
MATCH (n WHERE n.group_id IN $group_ids)
MATCH (n:Entity)
WHERE $group_ids IS NULL OR n.group_id IN $group_ids
RETURN
n.uuid As uuid,
n.group_id AS group_id,
@ -394,18 +366,17 @@ async def node_similarity_search(
async def community_fulltext_search(
driver: AsyncDriver,
query: str,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
group_ids = group_ids if group_ids is not None else [None]
# BM25 search to get top communities
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("community_name", $query)
YIELD node AS comm, score
MATCH (comm WHERE comm.group_id in $group_ids)
MATCH (comm:Community)
WHERE $group_ids IS NULL OR comm.group_id in $group_ids
RETURN
comm.uuid AS uuid,
comm.group_id AS group_id,
@ -428,17 +399,16 @@ async def community_fulltext_search(
async def community_similarity_search(
driver: AsyncDriver,
search_vector: list[float],
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
group_ids = group_ids if group_ids is not None else [None]
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
"""
CALL db.index.vector.queryNodes("community_name_embedding", $limit, $search_vector)
YIELD node AS comm, score
MATCH (comm WHERE comm.group_id IN $group_ids)
MATCH (comm:Community)
WHERE $group_ids IS NULL OR comm.group_id IN $group_ids
RETURN
comm.uuid As uuid,
comm.group_id AS group_id,
@ -461,7 +431,7 @@ async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
"""
@ -503,7 +473,6 @@ async def hybrid_node_search(
"""
start = time()
results: list[list[EntityNode]] = list(
await asyncio.gather(
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
@ -625,14 +594,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
async def node_distance_reranker(
driver: AsyncDriver, node_uuids: list[list[str]], center_node_uuid: str
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
) -> list[str]:
# use rrf as a preliminary ranker
sorted_uuids = rrf(node_uuids)
# filter out the center node uuid
filtered_uuids = list(filter(lambda uuid: uuid != center_node_uuid, node_uuids))
scores: dict[str, float] = {}
# Find the shortest path to center node
query = Query("""
query = Query("""
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: $node_uuid})
RETURN length(p) AS score
""")
@ -644,21 +613,23 @@ async def node_distance_reranker(
node_uuid=uuid,
center_uuid=center_node_uuid,
)
for uuid in sorted_uuids
for uuid in filtered_uuids
]
)
for uuid, result in zip(sorted_uuids, path_results):
for uuid, result in zip(filtered_uuids, path_results):
records = result[0]
record = records[0] if len(records) > 0 else None
distance: float = record['score'] if record is not None else float('inf')
distance = 0 if uuid == center_node_uuid else distance
scores[uuid] = distance
# rerank on shortest distance
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
return sorted_uuids
# add the filtered-out center node uuid back to the front
filtered_uuids = [center_node_uuid] + filtered_uuids
return filtered_uuids
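`node_distance_reranker` now takes a flat, pre-sorted list of node uuids, drops the center node before querying shortest paths, sorts the rest by path length, and prepends the center node to the result. A toy walk-through with made-up distances:

```python
center_node_uuid = 'node-center'
node_uuids = ['node-a', 'node-center', 'node-b']  # already rrf-sorted

# the center node is excluded from the shortest-path queries
filtered_uuids = [uuid for uuid in node_uuids if uuid != center_node_uuid]

# pretend these are the path lengths returned by the SHORTEST query;
# unreachable nodes would score float('inf')
scores = {'node-a': 2.0, 'node-b': 1.0}

filtered_uuids.sort(key=lambda uuid: scores[uuid])
print([center_node_uuid] + filtered_uuids)  # ['node-center', 'node-b', 'node-a']
```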
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:


@ -154,7 +154,7 @@ async def generate_summary_description(llm_client: LLMClient, summary: str) -> s
async def build_community(
llm_client: LLMClient, community_cluster: list[EntityNode]
llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
summaries = [entity.summary for entity in community_cluster]
length = len(summaries)
@ -168,7 +168,7 @@ async def build_community(
*[
summarize_pair(llm_client, (str(left_summary), str(right_summary)))
for left_summary, right_summary in zip(
summaries[: int(length / 2)], summaries[int(length / 2):]
summaries[: int(length / 2)], summaries[int(length / 2) :]
)
]
)
@ -196,7 +196,7 @@ async def build_community(
async def build_communities(
driver: AsyncDriver, llm_client: LLMClient
driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
community_clusters = await get_community_clusters(driver)
@ -227,7 +227,7 @@ async def remove_communities(driver: AsyncDriver):
async def determine_entity_community(
driver: AsyncDriver, entity: EntityNode
driver: AsyncDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
# Check if the node is already part of a community
records, _, _ = await driver.execute_query(
@ -288,7 +288,7 @@ async def determine_entity_community(
async def update_community(
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
driver: AsyncDriver, llm_client: LLMClient, embedder, entity: EntityNode
):
community, is_new = await determine_entity_community(driver, entity)
@ -307,4 +307,4 @@ async def update_community(
await community.generate_name_embedding(embedder)
await community.save(driver)
await community.save(driver)


@ -73,7 +73,7 @@ async def extract_edges(
episode: EpisodicNode,
nodes: list[EntityNode],
previous_episodes: list[EpisodicNode],
group_id: str | None,
group_id: str = '',
) -> list[EntityEdge]:
start = time()


@ -101,7 +101,7 @@ async def retrieve_episodes(
driver: AsyncDriver,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
group_ids: list[str | None] | None = None,
group_ids: list[str] | None = None,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@ -119,7 +119,8 @@ async def retrieve_episodes(
"""
result = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time AND e.group_id in $group_ids
MATCH (e:Episodic) WHERE e.valid_at <= $reference_time
AND ($group_ids IS NULL OR e.group_id in $group_ids)
RETURN e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
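As with the search paths, `group_ids=None` here means episodes from every partition. A calling sketch, assuming `driver` is a connected `neo4j.AsyncDriver`:

```python
from datetime import datetime, timezone

episodes = await retrieve_episodes(
    driver,
    reference_time=datetime.now(timezone.utc),
    last_n=3,
    group_ids=None,  # no partition filter; pass ['1'] to restrict to group '1'
)
```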


@ -76,16 +76,18 @@ async def test_graphiti_init():
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities()
edges = await graphiti.search('tania tetlow', group_ids=['1'])
edges = await graphiti.search(
'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
)
logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))
edges = await graphiti.search('issues with higher ed', group_ids=['1'])
edges = await graphiti.search('issues with higher ed', group_ids=None)
logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))
results = await graphiti._search(
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=['1']
'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None
)
pretty_results = {
'edges': [edge.fact for edge in results.edges],