mirror of
https://github.com/getzep/graphiti.git
synced 2025-12-29 08:05:02 +00:00
Community nodes (#103)
* add gds
* community work
* save progress
* community updates
* e2e communities
* troubleshooting
* updates
* communities
* remove unused import
This commit is contained in:
parent 4122d350a5
commit c0a740ff60
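In short, this change adds community detection on top of the entity graph: it projects Entity nodes and RELATES_TO edges into a Neo4j GDS in-memory graph, clusters them with Leiden, summarizes each cluster with the LLM, and stores the result as Community nodes linked to their members by HAS_MEMBER edges. A minimal usage sketch of the new entry point follows; the connection values are placeholders, it assumes the usual `from graphiti_core import Graphiti` import, a Neo4j instance with the GDS plugin installed, and episodes that have already been ingested:

import asyncio

from graphiti_core import Graphiti


async def main():
    # Placeholder connection values; point these at a Neo4j instance that has
    # the Graph Data Science (GDS) plugin installed, since community building
    # calls gds.graph.project and gds.leiden.stream.
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    # Assumes episodes were already ingested (add_episode / add_episode_bulk);
    # build_communities() clusters the existing entity graph, summarizes each
    # cluster, and saves Community nodes plus HAS_MEMBER edges.
    await graphiti.build_communities()


asyncio.run(main())

Note that build_communities() first clears any existing Community nodes via remove_communities, so it can be re-run as the graph grows.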
@@ -84,7 +84,7 @@ async def main(use_bulk: bool = True):
        for i, message in enumerate(messages[3:20])
    ]

    await client.add_episode_bulk(episodes)
    await client.add_episode_bulk(episodes, None)


asyncio.run(main(False))
@@ -41,8 +41,18 @@ class Edge(BaseModel, ABC):
    @abstractmethod
    async def save(self, driver: AsyncDriver): ...

    @abstractmethod
    async def delete(self, driver: AsyncDriver): ...
    async def delete(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
        MATCH (n)-[e {uuid: $uuid}]->(m)
        DELETE e
        """,
            uuid=self.uuid,
        )

        logger.info(f'Deleted Edge: {self.uuid}')

        return result

    def __hash__(self):
        return hash(self.uuid)
@@ -76,19 +86,6 @@ class EpisodicEdge(Edge):

        return result

    async def delete(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
        MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
        DELETE e
        """,
            uuid=self.uuid,
        )

        logger.info(f'Deleted Edge: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
        records, _, _ = await driver.execute_query(
@@ -169,19 +166,6 @@ class EntityEdge(Edge):

        return result

    async def delete(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
        MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
        DELETE e
        """,
            uuid=self.uuid,
        )

        logger.info(f'Deleted Edge: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
        records, _, _ = await driver.execute_query(
@@ -211,6 +195,48 @@ class EntityEdge(Edge):
        return edges[0]


class CommunityEdge(Edge):
    async def save(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
        MATCH (community:Community {uuid: $community_uuid})
        MATCH (node:Entity | Community {uuid: $entity_uuid})
        MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
        SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
        RETURN r.uuid AS uuid""",
            community_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.info(f'Saved edge to neo4j: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
        records, _, _ = await driver.execute_query(
            """
        MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """,
            uuid=uuid,
        )

        edges = [get_community_edge_from_record(record) for record in records]

        logger.info(f'Found Edge: {uuid}')

        return edges[0]


# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
    return EpisodicEdge(
@@ -237,3 +263,13 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
        valid_at=parse_db_date(record['valid_at']),
        invalid_at=parse_db_date(record['invalid_at']),
    )


def get_community_edge_from_record(record: Any):
    return CommunityEdge(
        uuid=record['uuid'],
        group_id=record['group_id'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        created_at=record['created_at'].to_native(),
    )
@@ -46,6 +46,10 @@ from graphiti_core.utils.bulk_utils import (
    resolve_edge_pointers,
    retrieve_previous_episodes_bulk,
)
from graphiti_core.utils.maintenance.community_operations import (
    build_communities,
    remove_communities,
)
from graphiti_core.utils.maintenance.edge_operations import (
    extract_edges,
    resolve_extracted_edges,
@@ -526,6 +530,19 @@ class Graphiti:
        except Exception as e:
            raise e

    async def build_communities(self):
        embedder = self.llm_client.get_embedder()

        # Clear existing communities
        await remove_communities(self.driver)

        community_nodes, community_edges = await build_communities(self.driver, self.llm_client)

        await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])

        await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
        await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])

    async def search(
        self,
        query: str,
@@ -76,8 +76,18 @@ class Node(BaseModel, ABC):
    @abstractmethod
    async def save(self, driver: AsyncDriver): ...

    @abstractmethod
    async def delete(self, driver: AsyncDriver): ...
    async def delete(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
        MATCH (n {uuid: $uuid})
        DETACH DELETE n
        """,
            uuid=self.uuid,
        )

        logger.info(f'Deleted Node: {self.uuid}')

        return result

    def __hash__(self):
        return hash(self.uuid)
@@ -90,6 +100,9 @@ class Node(BaseModel, ABC):
    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...

    @classmethod
    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...


class EpisodicNode(Node):
    source: EpisodeType = Field(description='source type')
@@ -125,19 +138,6 @@ class EpisodicNode(Node):

        return result

    async def delete(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
        MATCH (n:Episodic {uuid: $uuid})
        DETACH DELETE n
        """,
            uuid=self.uuid,
        )

        logger.info(f'Deleted Node: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
        records, _, _ = await driver.execute_query(
@@ -161,6 +161,29 @@ class EpisodicNode(Node):

        return episodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
        records, _, _ = await driver.execute_query(
            """
        MATCH (e:Episodic) WHERE e.uuid IN $uuids
        RETURN e.content AS content,
            e.created_at AS created_at,
            e.valid_at AS valid_at,
            e.uuid AS uuid,
            e.name AS name,
            e.group_id AS group_id,
            e.source_description AS source_description,
            e.source AS source
        """,
            uuids=uuids,
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        logger.info(f'Found Nodes: {uuids}')

        return episodes


class EntityNode(Node):
    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
@@ -194,19 +217,6 @@ class EntityNode(Node):

        return result

    async def delete(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
        MATCH (n:Entity {uuid: $uuid})
        DETACH DELETE n
        """,
            uuid=self.uuid,
        )

        logger.info(f'Deleted Node: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
        records, _, _ = await driver.execute_query(
@@ -229,6 +239,105 @@ class EntityNode(Node):

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
        records, _, _ = await driver.execute_query(
            """
        MATCH (n:Entity) WHERE n.uuid IN $uuids
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
        """,
            uuids=uuids,
        )

        nodes = [get_entity_node_from_record(record) for record in records]

        logger.info(f'Found Nodes: {uuids}')

        return nodes


class CommunityNode(Node):
    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='region summary of member nodes', default_factory=str)

    async def save(self, driver: AsyncDriver):
        result = await driver.execute_query(
            """
        MERGE (n:Community {uuid: $uuid})
        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
        RETURN n.uuid AS uuid""",
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            summary=self.summary,
            name_embedding=self.name_embedding,
            created_at=self.created_at,
        )

        logger.info(f'Saved Node to neo4j: {self.uuid}')

        return result

    async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
        start = time()
        text = self.name.replace('\n', ' ')
        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
        self.name_embedding = embedding[:EMBEDDING_DIM]
        end = time()
        logger.info(f'embedded {text} in {end - start} ms')

        return embedding

    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
        records, _, _ = await driver.execute_query(
            """
        MATCH (n:Community {uuid: $uuid})
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
        """,
            uuid=uuid,
        )

        nodes = [get_community_node_from_record(record) for record in records]

        logger.info(f'Found Node: {uuid}')

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
        records, _, _ = await driver.execute_query(
            """
        MATCH (n:Community) WHERE n.uuid IN $uuids
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
        """,
            uuids=uuids,
        )

        nodes = [get_community_node_from_record(record) for record in records]

        logger.info(f'Found Nodes: {uuids}')

        return nodes


# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
@@ -254,3 +363,14 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
        created_at=record['created_at'].to_native(),
        summary=record['summary'],
    )


def get_community_node_from_record(record: Any) -> CommunityNode:
    return CommunityNode(
        uuid=record['uuid'],
        name=record['name'],
        group_id=record['group_id'],
        name_embedding=record['name_embedding'],
        created_at=record['created_at'].to_native(),
        summary=record['summary'],
    )
@@ -71,6 +71,9 @@ from .invalidate_edges import (
    versions as invalidate_edges_versions,
)
from .models import Message, PromptFunction
from .summarize_nodes import Prompt as SummarizeNodesPrompt
from .summarize_nodes import Versions as SummarizeNodesVersions
from .summarize_nodes import versions as summarize_nodes_versions


class PromptLibrary(Protocol):
@@ -80,6 +83,7 @@ class PromptLibrary(Protocol):
    dedupe_edges: DedupeEdgesPrompt
    invalidate_edges: InvalidateEdgesPrompt
    extract_edge_dates: ExtractEdgeDatesPrompt
    summarize_nodes: SummarizeNodesPrompt


class PromptLibraryImpl(TypedDict):
@@ -89,6 +93,7 @@ class PromptLibraryImpl(TypedDict):
    dedupe_edges: DedupeEdgesVersions
    invalidate_edges: InvalidateEdgesVersions
    extract_edge_dates: ExtractEdgeDatesVersions
    summarize_nodes: SummarizeNodesVersions


class VersionWrapper:
@@ -118,5 +123,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
    'dedupe_edges': dedupe_edges_versions,
    'invalidate_edges': invalidate_edges_versions,
    'extract_edge_dates': extract_edge_dates_versions,
    'summarize_nodes': summarize_nodes_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)  # type: ignore[assignment]
graphiti_core/prompts/summarize_nodes.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion


class Prompt(Protocol):
    summarize_pair: PromptVersion
    summary_description: PromptVersion


class Versions(TypedDict):
    summarize_pair: PromptFunction
    summary_description: PromptFunction


def summarize_pair(context: dict[str, Any]) -> list[Message]:
    return [
        Message(
            role='system',
            content='You are a helpful assistant that combines summaries.',
        ),
        Message(
            role='user',
            content=f"""
        Synthesize the information from the following two summaries into a single succinct summary.

        Summaries:
        {json.dumps(context['node_summaries'], indent=2)}

        Respond with a JSON object in the following format:
        {{
            "summary": "Summary containing the important information from both summaries"
        }}
        """,
        ),
    ]


def summary_description(context: dict[str, Any]) -> list[Message]:
    return [
        Message(
            role='system',
            content='You are a helpful assistant that describes provided contents in a single sentence.',
        ),
        Message(
            role='user',
            content=f"""
        Create a short one sentence description of the summary that explains what kind of information is summarized.

        Summary:
        {json.dumps(context['summary'], indent=2)}

        Respond with a JSON object in the following format:
        {{
            "description": "One sentence description of the provided summary"
        }}
        """,
        ),
    ]


versions: Versions = {'summarize_pair': summarize_pair, 'summary_description': summary_description}
graphiti_core/utils/maintenance/community_operations.py (new file, 155 lines)
@@ -0,0 +1,155 @@
import asyncio
import logging
from collections import defaultdict
from datetime import datetime

from neo4j import AsyncDriver

from graphiti_core.edges import CommunityEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.prompts import prompt_library
from graphiti_core.utils.maintenance.edge_operations import build_community_edges

logger = logging.getLogger(__name__)


async def build_community_projection(driver: AsyncDriver) -> str:
    records, _, _ = await driver.execute_query("""
    CALL gds.graph.project("communities", "Entity",
        {RELATES_TO: {
            type: "RELATES_TO",
            orientation: "UNDIRECTED",
            properties: {weight: {property: "*", aggregation: "COUNT"}}
        }}
    )
    YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
    """)

    return records[0]['graph']


async def destroy_projection(driver: AsyncDriver, projection_name: str):
    await driver.execute_query(
        """
    CALL gds.graph.drop($projection_name)
    """,
        projection_name=projection_name,
    )


async def get_community_clusters(
    driver: AsyncDriver, projection_name: str
) -> list[list[EntityNode]]:
    records, _, _ = await driver.execute_query("""
    CALL gds.leiden.stream("communities")
    YIELD nodeId, communityId
    RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
    """)
    community_map: dict[int, list[str]] = defaultdict(list)
    for record in records:
        community_map[record['communityId']].append(record['entity_uuid'])

    community_clusters: list[list[EntityNode]] = list(
        await asyncio.gather(
            *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
        )
    )

    return community_clusters


async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
    # Prepare context for LLM
    context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}

    llm_response = await llm_client.generate_response(
        prompt_library.summarize_nodes.summarize_pair(context)
    )

    pair_summary = llm_response.get('summary', '')

    return pair_summary


async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
    context = {'summary': summary}

    llm_response = await llm_client.generate_response(
        prompt_library.summarize_nodes.summary_description(context)
    )

    description = llm_response.get('description', '')

    return description


async def build_community(
    llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
    summaries = [entity.summary for entity in community_cluster]
    length = len(summaries)
    while length > 1:
        odd_one_out: str | None = None
        if length % 2 == 1:
            odd_one_out = summaries.pop()
            length -= 1
        new_summaries: list[str] = list(
            await asyncio.gather(
                *[
                    summarize_pair(llm_client, (str(left_summary), str(right_summary)))
                    for left_summary, right_summary in zip(
                        summaries[: int(length / 2)], summaries[int(length / 2) :]
                    )
                ]
            )
        )
        if odd_one_out is not None:
            new_summaries.append(odd_one_out)
        summaries = new_summaries
        length = len(summaries)

    summary = summaries[0]
    name = await generate_summary_description(llm_client, summary)
    now = datetime.now()
    community_node = CommunityNode(
        name=name,
        group_id=community_cluster[0].group_id,
        labels=['Community'],
        created_at=now,
        summary=summary,
    )
    community_edges = build_community_edges(community_cluster, community_node, now)

    logger.info((community_node, community_edges))

    return community_node, community_edges


async def build_communities(
    driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
    projection = await build_community_projection(driver)
    community_clusters = await get_community_clusters(driver, projection)

    communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
        await asyncio.gather(
            *[build_community(llm_client, cluster) for cluster in community_clusters]
        )
    )

    community_nodes: list[CommunityNode] = []
    community_edges: list[CommunityEdge] = []
    for community in communities:
        community_nodes.append(community[0])
        community_edges.extend(community[1])

    await destroy_projection(driver, projection)
    return community_nodes, community_edges


async def remove_communities(driver: AsyncDriver):
    await driver.execute_query("""
    MATCH (c:Community)
    DETACH DELETE c
    """)
@@ -20,9 +20,9 @@ from datetime import datetime
from time import time
from typing import List

from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.utils.maintenance.temporal_operations import (
    extract_edge_dates,
@@ -50,6 +50,24 @@ def build_episodic_edges(
    return edges


def build_community_edges(
    entity_nodes: List[EntityNode],
    community_node: CommunityNode,
    created_at: datetime,
) -> List[CommunityEdge]:
    edges: List[CommunityEdge] = [
        CommunityEdge(
            source_node_uuid=community_node.uuid,
            target_node_uuid=node.uuid,
            created_at=created_at,
            group_id=community_node.group_id,
        )
        for node in entity_nodes
    ]

    return edges


async def extract_edges(
    llm_client: LLMClient,
    episode: EpisodicNode,
@@ -32,8 +32,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
    range_indices: list[LiteralString] = [
        'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
        'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
        'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
        'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
        'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
        'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
        'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
        'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
        'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
@@ -51,6 +53,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):

    fulltext_indices: list[LiteralString] = [
        'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
        'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
        'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
    ]

@@ -71,6 +74,14 @@ async def build_indices_and_constraints(driver: AsyncDriver):
            `vector.similarity_function`: 'cosine'
        }}
        """,
        """
        CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
        FOR (n:Community) ON (n.name_embedding)
        OPTIONS {indexConfig: {
            `vector.dimensions`: 1024,
            `vector.similarity_function`: 'cosine'
        }}
        """,
    ]
    index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
@@ -73,6 +73,7 @@ def format_context(facts):
async def test_graphiti_init():
    logger = setup_logging()
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
    await graphiti.build_communities()

    edges = await graphiti.search('Freakenomics guest', group_ids=['1'])