2024-09-11 12:06:35 -04:00
|
|
|
import asyncio
|
|
|
|
|
import logging
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
2024-09-23 11:05:44 -04:00
|
|
|
from pydantic import BaseModel
|
2024-09-11 12:06:35 -04:00
|
|
|
|
2025-06-13 12:06:57 -04:00
|
|
|
from graphiti_core.driver.driver import GraphDriver
|
2024-09-11 12:06:35 -04:00
|
|
|
from graphiti_core.edges import CommunityEdge
|
2024-09-27 12:47:04 -04:00
|
|
|
from graphiti_core.embedder import EmbedderClient
|
2025-07-10 14:25:39 -07:00
|
|
|
from graphiti_core.helpers import semaphore_gather
|
2024-09-11 12:06:35 -04:00
|
|
|
from graphiti_core.llm_client import LLMClient
|
Gemini support (#324)
* first cut
* Update dependencies and enhance README for optional LLM providers
- Bump aiohttp version from 3.11.14 to 3.11.16
- Update yarl version from 1.18.3 to 1.19.0
- Modify pyproject.toml to include optional extras for Anthropic, Groq, and Google Gemini
- Revise README.md to reflect new optional LLM provider installation instructions and clarify API key requirements
* Remove deprecated packages from poetry.lock and update content hash
- Removed cachetools, google-auth, google-genai, pyasn1, pyasn1-modules, rsa, and websockets from the lock file.
- Added new extras for anthropic, google-genai, and groq.
- Updated content hash to reflect changes.
* Refactor import paths for GeminiClient in README and __init__.py
- Updated import statement in README.md to reflect the new module structure for GeminiClient.
- Removed GeminiClient from the __all__ list in __init__.py as it is no longer directly imported.
* Refactor import paths for GeminiEmbedder in README and __init__.py
- Updated import statement in README.md to reflect the new module structure for GeminiEmbedder.
- Removed GeminiEmbedder and GeminiEmbedderConfig from the __all__ list in __init__.py as they are no longer directly imported.
2025-04-06 09:27:04 -07:00
|
|
|
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
|
2024-09-11 12:06:35 -04:00
|
|
|
from graphiti_core.prompts import prompt_library
|
2024-12-05 07:03:18 -08:00
|
|
|
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
|
2024-12-09 10:36:04 -08:00
|
|
|
from graphiti_core.utils.datetime_utils import utc_now
|
2024-09-11 12:06:35 -04:00
|
|
|
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
|
|
|
|
|
2024-09-22 13:38:54 -07:00
|
|
|
# Upper bound on how many clusters build_communities builds concurrently.
MAX_COMMUNITY_BUILD_CONCURRENCY: int = 10

logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
2024-09-23 11:05:44 -04:00
|
|
|
class Neighbor(BaseModel):
    """A weighted adjacency entry of the in-memory entity projection.

    Built by ``get_community_clusters`` and consumed by ``label_propagation``.
    """

    # uuid of the neighboring entity node
    node_uuid: str
    # number of RELATES_TO edges linking the two entities (Cypher ``count(r)``)
    edge_count: int
|
|
|
|
|
|
|
|
|
|
|
2024-10-08 13:55:10 -04:00
|
|
|
async def get_community_clusters(
    driver: GraphDriver, group_ids: list[str] | None
) -> list[list[EntityNode]]:
    """Group the entities of each graph partition into community clusters.

    When ``group_ids`` is ``None``, every distinct non-null ``group_id`` found on
    ``Entity`` nodes is processed. For each group, the ``RELATES_TO`` neighborhood of
    every entity (restricted to entities of the same group) is loaded into an
    in-memory projection. The projection is partitioned with ``label_propagation``,
    and each resulting uuid cluster is resolved back to ``EntityNode`` objects.

    Returns:
        One list of entity nodes per detected community, with groups processed in order.
    """
    if group_ids is None:
        group_rows, _, _ = await driver.execute_query(
            """
        MATCH (n:Entity WHERE n.group_id IS NOT NULL)
        RETURN
            collect(DISTINCT n.group_id) AS group_ids
        """,
        )
        group_ids = [] if not group_rows else group_rows[0]['group_ids']

    clusters: list[list[EntityNode]] = []
    for group_id in group_ids:
        entities = await EntityNode.get_by_group_ids(driver, [group_id])

        # uuid -> weighted neighbors within the same group, queried one entity at a time.
        adjacency: dict[str, list[Neighbor]] = {}
        for entity in entities:
            neighbor_rows, _, _ = await driver.execute_query(
                """
            MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[r:RELATES_TO]-(m: Entity {group_id: $group_id})
            WITH count(r) AS count, m.uuid AS uuid
            RETURN
                uuid,
                count
            """,
                uuid=entity.uuid,
                group_id=group_id,
            )
            adjacency[entity.uuid] = [
                Neighbor(node_uuid=row['uuid'], edge_count=row['count']) for row in neighbor_rows
            ]

        uuid_clusters = label_propagation(adjacency)
        resolved = await semaphore_gather(
            *[EntityNode.get_by_uuids(driver, uuid_cluster) for uuid_cluster in uuid_clusters]
        )
        clusters.extend(resolved)

    return clusters
|
|
|
|
|
|
|
|
|
|
|
2024-09-23 11:05:44 -04:00
|
|
|
def label_propagation(projection: 'dict[str, list[Neighbor]]') -> list[list[str]]:
    """Partition the projected graph into communities via label propagation.

    Every node starts in its own community, labelled by its insertion index in
    ``projection``. On each pass, every node sums ``edge_count`` per label over its
    neighbors' labels from the previous pass:

    * if the heaviest label has total weight > 1, the node adopts it (equally heavy
      labels are resolved in favour of the higher label id);
    * otherwise (no neighbors, or a single weight-1 edge), the node takes the
      higher of its current label and the best neighbor label.

    Updates are synchronous, which can make labels oscillate indefinitely. For
    example, two nodes joined by two edges swap labels on every pass. Propagation
    therefore stops either when a pass changes nothing or when a pass reproduces a
    labelling already seen. Because labels are drawn from a finite set, this
    always terminates.

    Args:
        projection: adjacency map from node uuid to its weighted neighbors. Every
            ``neighbor.node_uuid`` must also be a key of ``projection``
            (otherwise ``KeyError``).

    Returns:
        One list of node uuids per community, in first-seen order.
    """
    community_map = {uuid: i for i, uuid in enumerate(projection)}
    # Labellings already visited; seeing one again means the synchronous update cycles.
    seen_states: set[tuple[int, ...]] = {tuple(community_map.values())}

    while True:
        no_change = True
        new_community_map: dict[str, int] = {}

        for uuid, neighbors in projection.items():
            curr_community = community_map[uuid]

            community_candidates: dict[int, int] = defaultdict(int)
            for neighbor in neighbors:
                community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count

            # Highest total weight wins; equal weights fall back to the higher label id.
            candidate_rank, community_candidate = max(
                ((count, community) for community, count in community_candidates.items()),
                default=(0, -1),
            )
            if community_candidate != -1 and candidate_rank > 1:
                new_community = community_candidate
            else:
                new_community = max(community_candidate, curr_community)

            new_community_map[uuid] = new_community
            if new_community != curr_community:
                no_change = False

        if no_change:
            break

        state = tuple(new_community_map.values())
        if state in seen_states:
            # Oscillation: keep the current labelling instead of looping forever.
            break
        seen_states.add(state)
        community_map = new_community_map

    community_cluster_map: dict[int, list[str]] = defaultdict(list)
    for uuid, community in community_map.items():
        community_cluster_map[community].append(uuid)

    return list(community_cluster_map.values())
|
|
|
|
|
|
|
|
|
|
|
2024-09-11 12:06:35 -04:00
|
|
|
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
    """Ask the LLM to merge two summaries into a single one.

    Args:
        llm_client: client used to run the ``summarize_nodes.summarize_pair`` prompt.
        summary_pair: the two summaries to combine.

    Returns:
        The ``summary`` field of the structured response, or ``''`` if it is missing.
    """
    node_summaries = [{'summary': text} for text in summary_pair]
    prompt = prompt_library.summarize_nodes.summarize_pair({'node_summaries': node_summaries})
    response = await llm_client.generate_response(prompt, response_model=Summary)
    return response.get('summary', '')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
    """Ask the LLM for a short description of ``summary``.

    Within this module the result is used as a community's ``name``.

    Returns:
        The ``description`` field of the structured response, or ``''`` if it is missing.
    """
    prompt = prompt_library.summarize_nodes.summary_description({'summary': summary})
    response = await llm_client.generate_response(prompt, response_model=SummaryDescription)
    return response.get('description', '')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def build_community(
    llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
    """Summarize a cluster of entities into a community node plus membership edges.

    Entity summaries are merged tournament-style. Each round pairs the first half of
    the list with the second half and merges every pair concurrently through
    ``summarize_pair``. With an odd count, the last summary sits the round out and
    is appended after the merged ones. Rounds repeat until a single summary
    remains. That summary becomes the community summary, and its LLM-generated
    description becomes the community name.

    ``community_cluster`` must be non-empty: its first entity supplies the
    community's ``group_id``.
    """
    summaries = [entity.summary for entity in community_cluster]

    while len(summaries) > 1:
        leftover = summaries.pop() if len(summaries) % 2 else None
        half = len(summaries) // 2
        pairs = zip(summaries[:half], summaries[half:], strict=False)
        merged: list[str] = list(
            await semaphore_gather(
                *[summarize_pair(llm_client, (str(left), str(right))) for left, right in pairs]
            )
        )
        if leftover is not None:
            merged.append(leftover)
        summaries = merged

    summary = summaries[0]
    name = await generate_summary_description(llm_client, summary)
    created_at = utc_now()
    community_node = CommunityNode(
        name=name,
        group_id=community_cluster[0].group_id,
        labels=['Community'],
        created_at=created_at,
        summary=summary,
    )
    community_edges = build_community_edges(community_cluster, community_node, created_at)

    logger.debug((community_node, community_edges))

    return community_node, community_edges
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def build_communities(
    driver: GraphDriver, llm_client: LLMClient, group_ids: list[str] | None
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
    """Detect community clusters and build a community node for each of them.

    At most ``MAX_COMMUNITY_BUILD_CONCURRENCY`` clusters are built at once.

    Returns:
        All community nodes, and the membership edges of all communities flattened
        into one list.
    """
    clusters = await get_community_clusters(driver, group_ids)

    # Caps concurrent build_community calls; each one issues several LLM requests.
    limiter = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)

    async def limited_build_community(cluster):
        async with limiter:
            return await build_community(llm_client, cluster)

    results: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
        await semaphore_gather(*[limited_build_community(cluster) for cluster in clusters])
    )

    community_nodes: list[CommunityNode] = [node for node, _ in results]
    community_edges: list[CommunityEdge] = [edge for _, edges in results for edge in edges]
    return community_nodes, community_edges
|
|
|
|
|
|
|
|
|
|
|
2025-06-13 12:06:57 -04:00
|
|
|
async def remove_communities(driver: GraphDriver):
    """Delete every ``Community`` node together with all of its relationships."""
    delete_all_communities = """
    MATCH (c:Community)
    DETACH DELETE c
    """
    await driver.execute_query(delete_all_communities)
|
2024-09-18 11:37:34 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def determine_entity_community(
    driver: GraphDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
    """Find the community ``entity`` belongs to, or the one it should join.

    Returns:
        ``(community, False)`` when the entity is already a ``HAS_MEMBER`` of a
        community (the first returned record is used).
        ``(community, True)`` when it is not yet a member: the chosen community is
        the one appearing in the most rows reached through the entity's
        ``RELATES_TO`` neighbors, with ties going to the one returned first.
        ``(None, False)`` when neither query finds a community.
    """
    membership_records, _, _ = await driver.execute_query(
        """
    MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
    RETURN
        c.uuid As uuid,
        c.name AS name,
        c.group_id AS group_id,
        c.created_at AS created_at,
        c.summary AS summary
    """,
        entity_uuid=entity.uuid,
    )
    if membership_records:
        return get_community_node_from_record(membership_records[0]), False

    # Not a member yet: pick the mode community among neighboring entities.
    neighbor_records, _, _ = await driver.execute_query(
        """
    MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
    RETURN
        c.uuid As uuid,
        c.name AS name,
        c.group_id AS group_id,
        c.created_at AS created_at,
        c.summary AS summary
    """,
        entity_uuid=entity.uuid,
    )
    candidates = [get_community_node_from_record(record) for record in neighbor_records]
    if not candidates:
        return None, False

    # Each returned row counts once, so a community reached via several neighbors weighs more.
    tally: dict[str, int] = defaultdict(int)
    first_seen: dict[str, CommunityNode] = {}
    for candidate in candidates:
        tally[candidate.uuid] += 1
        first_seen.setdefault(candidate.uuid, candidate)

    # max() keeps the earliest key among equal counts, like a strict '>' scan would.
    winner = max(tally, key=tally.__getitem__)
    return first_seen[winner], True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def update_community(
    driver: GraphDriver, llm_client: LLMClient, embedder: EmbedderClient, entity: EntityNode
):
    """Fold ``entity`` into its existing or inferred community and persist the result.

    Does nothing when ``determine_entity_community`` finds no community. Otherwise:

    * the community summary is re-merged with the entity's summary, and the name is
      regenerated from it;
    * a membership edge is saved when the entity is newly joining;
    * the name embedding is refreshed and the community is saved.
    """
    community, is_new = await determine_entity_community(driver, entity)
    if community is None:
        return

    merged_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
    merged_name = await generate_summary_description(llm_client, merged_summary)
    community.summary, community.name = merged_summary, merged_name

    if is_new:
        membership_edges = build_community_edges([entity], community, utc_now())
        await membership_edges[0].save(driver)

    await community.generate_name_embedding(embedder)
    await community.save(driver)
|