Community nodes (#103)

* add gds

* community work

* save progress

* community updates

* e2e communities

* troubleshooting

* updates

* communities

* remove unused import
This commit is contained in:
Preston Rasmussen 2024-09-11 12:06:35 -04:00 committed by GitHub
parent 4122d350a5
commit c0a740ff60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 502 additions and 59 deletions

View File

@ -84,7 +84,7 @@ async def main(use_bulk: bool = True):
for i, message in enumerate(messages[3:20])
]
await client.add_episode_bulk(episodes)
await client.add_episode_bulk(episodes, None)
asyncio.run(main(False))

View File

@ -41,8 +41,18 @@ class Edge(BaseModel, ABC):
@abstractmethod
async def save(self, driver: AsyncDriver): ...
@abstractmethod
async def delete(self, driver: AsyncDriver): ...
async def delete(self, driver: AsyncDriver):
    """Delete the relationship whose ``uuid`` property equals this edge's uuid.

    The pattern matches on ``uuid`` alone; it does not constrain the
    relationship type or the labels of the endpoint nodes.

    Returns:
        The value returned by ``driver.execute_query``.
    """
    delete_query = """
        MATCH (n)-[e {uuid: $uuid}]->(m)
        DELETE e
        """
    outcome = await driver.execute_query(delete_query, uuid=self.uuid)

    logger.info(f'Deleted Edge: {self.uuid}')

    return outcome
def __hash__(self):
return hash(self.uuid)
@ -76,19 +86,6 @@ class EpisodicEdge(Edge):
return result
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
DELETE e
""",
uuid=self.uuid,
)
logger.info(f'Deleted Edge: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
@ -169,19 +166,6 @@ class EntityEdge(Edge):
return result
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
DELETE e
""",
uuid=self.uuid,
)
logger.info(f'Deleted Edge: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
@ -211,6 +195,48 @@ class EntityEdge(Edge):
return edges[0]
class CommunityEdge(Edge):
    """``HAS_MEMBER`` relationship from a Community node to an Entity or Community node."""

    async def save(self, driver: AsyncDriver):
        """Upsert this edge between its source community and its target member.

        Both endpoints are looked up by uuid; the relationship is MERGEd on its
        own uuid and its properties are then overwritten.

        Returns:
            The value returned by ``driver.execute_query``.
        """
        upsert_query = """
        MATCH (community:Community {uuid: $community_uuid})
        MATCH (node:Entity | Community {uuid: $entity_uuid})
        MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
        SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
        RETURN r.uuid AS uuid"""
        outcome = await driver.execute_query(
            upsert_query,
            community_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.info(f'Saved edge to neo4j: {self.uuid}')

        return outcome

    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
        """Fetch the ``HAS_MEMBER`` edge with the given uuid.

        Raises:
            IndexError: if the query returns no rows (the first element of an
                empty list is taken).
        """
        lookup_query = """
        MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
        RETURN
            e.uuid As uuid,
            e.group_id AS group_id,
            n.uuid AS source_node_uuid,
            m.uuid AS target_node_uuid,
            e.created_at AS created_at
        """
        rows, _, _ = await driver.execute_query(lookup_query, uuid=uuid)

        found = [get_community_edge_from_record(row) for row in rows]

        logger.info(f'Found Edge: {uuid}')

        return found[0]
# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
return EpisodicEdge(
@ -237,3 +263,13 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
)
def get_community_edge_from_record(record: Any) -> CommunityEdge:
    """Build a ``CommunityEdge`` from one query result row.

    Args:
        record: Row exposing the keys ``uuid``, ``group_id``,
            ``source_node_uuid``, ``target_node_uuid`` and ``created_at``.

    Returns:
        The edge populated from those keys.
    """
    return CommunityEdge(
        uuid=record['uuid'],
        group_id=record['group_id'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        # presumably a neo4j temporal value converted to a Python datetime -- confirm
        created_at=record['created_at'].to_native(),
    )

View File

@ -46,6 +46,10 @@ from graphiti_core.utils.bulk_utils import (
resolve_edge_pointers,
retrieve_previous_episodes_bulk,
)
from graphiti_core.utils.maintenance.community_operations import (
build_communities,
remove_communities,
)
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
resolve_extracted_edges,
@ -526,6 +530,19 @@ class Graphiti:
except Exception as e:
raise e
async def build_communities(self):
    """Replace all stored communities with freshly built ones.

    Existing Community nodes are removed first; new community nodes and
    edges are then built, the nodes are given name embeddings, and nodes
    are saved before edges.
    """
    embedder = self.llm_client.get_embedder()

    # Clear existing communities
    await remove_communities(self.driver)

    new_nodes, new_edges = await build_communities(self.driver, self.llm_client)

    # Each stage is awaited in full before the next one starts.
    embedding_jobs = [node.generate_name_embedding(embedder) for node in new_nodes]
    await asyncio.gather(*embedding_jobs)

    node_saves = [node.save(self.driver) for node in new_nodes]
    await asyncio.gather(*node_saves)

    edge_saves = [edge.save(self.driver) for edge in new_edges]
    await asyncio.gather(*edge_saves)
async def search(
self,
query: str,

View File

@ -76,8 +76,18 @@ class Node(BaseModel, ABC):
@abstractmethod
async def save(self, driver: AsyncDriver): ...
@abstractmethod
async def delete(self, driver: AsyncDriver): ...
async def delete(self, driver: AsyncDriver):
    """Detach-delete the node whose ``uuid`` property equals this node's uuid.

    The pattern matches on ``uuid`` alone, without a label constraint.

    Returns:
        The value returned by ``driver.execute_query``.
    """
    delete_query = """
        MATCH (n {uuid: $uuid})
        DETACH DELETE n
        """
    outcome = await driver.execute_query(delete_query, uuid=self.uuid)

    logger.info(f'Deleted Node: {self.uuid}')

    return outcome
def __hash__(self):
return hash(self.uuid)
@ -90,6 +100,9 @@ class Node(BaseModel, ABC):
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
class EpisodicNode(Node):
source: EpisodeType = Field(description='source type')
@ -125,19 +138,6 @@ class EpisodicNode(Node):
return result
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n:Episodic {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
logger.info(f'Deleted Node: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
@ -161,6 +161,29 @@ class EpisodicNode(Node):
return episodes[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Fetch every Episodic node whose uuid is in ``uuids``.

    Args:
        driver: Driver used to run the query.
        uuids: uuids to look up.

    Returns:
        A list of episodic nodes, one per returned row; empty when nothing matches.
    """
    records, _, _ = await driver.execute_query(
        """
        MATCH (e:Episodic) WHERE e.uuid IN $uuids
        RETURN e.content AS content,
            e.created_at AS created_at,
            e.valid_at AS valid_at,
            e.uuid AS uuid,
            e.name AS name,
            e.group_id AS group_id,
            e.source_description AS source_description,
            e.source AS source
        """,
        uuids=uuids,
    )

    episodes = [get_episodic_node_from_record(record) for record in records]

    logger.info(f'Found Nodes: {uuids}')

    return episodes
class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
@ -194,19 +217,6 @@ class EntityNode(Node):
return result
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
logger.info(f'Deleted Node: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
@ -229,6 +239,105 @@ class EntityNode(Node):
return nodes[0]
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Fetch every Entity node whose uuid is in ``uuids``.

    Returns:
        A list of entity nodes, one per returned row; empty when nothing matches.
    """
    lookup_query = """
        MATCH (n:Entity) WHERE n.uuid IN $uuids
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
        """
    rows, _, _ = await driver.execute_query(lookup_query, uuids=uuids)

    found = [get_entity_node_from_record(row) for row in rows]

    logger.info(f'Found Nodes: {uuids}')

    return found
class CommunityNode(Node):
    """Node labelled ``Community`` carrying a summary of its member nodes."""

    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='region summary of member nodes', default_factory=str)

    async def save(self, driver: AsyncDriver):
        """Upsert this community node, overwriting all of its stored properties.

        Returns:
            The value returned by ``driver.execute_query``.
        """
        result = await driver.execute_query(
            """
            MERGE (n:Community {uuid: $uuid})
            SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
            RETURN n.uuid AS uuid""",
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            summary=self.summary,
            name_embedding=self.name_embedding,
            created_at=self.created_at,
        )

        logger.info(f'Saved Node to neo4j: {self.uuid}')

        return result

    async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
        """Embed this node's name and store the result on ``name_embedding``.

        Args:
            embedder: Object exposing an awaitable ``create(input=..., model=...)``
                whose result has ``.data[0].embedding``.
            model: Embedding model name passed through to ``embedder.create``.

        Returns:
            The full embedding as returned by the embedder. The stored
            ``name_embedding`` is this value truncated to ``EMBEDDING_DIM``.
        """
        start = time()

        text = self.name.replace('\n', ' ')
        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
        self.name_embedding = embedding[:EMBEDDING_DIM]

        # time() measures seconds; convert so the logged figure matches its "ms" label.
        elapsed_ms = (time() - start) * 1000
        logger.info(f'embedded {text} in {elapsed_ms} ms')

        return embedding

    @classmethod
    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
        """Fetch the Community node with the given uuid.

        Raises:
            IndexError: if the query returns no rows (the first element of an
                empty list is taken).
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community {uuid: $uuid})
            RETURN
                n.uuid As uuid,
                n.name AS name,
                n.name_embedding AS name_embedding,
                n.group_id AS group_id,
                n.created_at AS created_at,
                n.summary AS summary
            """,
            uuid=uuid,
        )

        nodes = [get_community_node_from_record(record) for record in records]

        logger.info(f'Found Node: {uuid}')

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
        """Fetch every Community node whose uuid is in ``uuids``.

        Returns:
            A list of community nodes, one per returned row; empty when nothing matches.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community) WHERE n.uuid IN $uuids
            RETURN
                n.uuid As uuid,
                n.name AS name,
                n.name_embedding AS name_embedding,
                n.group_id AS group_id,
                n.created_at AS created_at,
                n.summary AS summary
            """,
            uuids=uuids,
        )

        nodes = [get_community_node_from_record(record) for record in records]

        logger.info(f'Found Nodes: {uuids}')

        return nodes
# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
@ -254,3 +363,14 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
created_at=record['created_at'].to_native(),
summary=record['summary'],
)
def get_community_node_from_record(record: Any) -> CommunityNode:
    """Build a ``CommunityNode`` from one query result row.

    The row must expose ``uuid``, ``name``, ``group_id``, ``name_embedding``,
    ``created_at`` and ``summary``.
    """
    # Keys are read in the same order as the constructor arguments they feed.
    fields = {
        'uuid': record['uuid'],
        'name': record['name'],
        'group_id': record['group_id'],
        'name_embedding': record['name_embedding'],
        'created_at': record['created_at'].to_native(),
        'summary': record['summary'],
    }
    return CommunityNode(**fields)

View File

@ -71,6 +71,9 @@ from .invalidate_edges import (
versions as invalidate_edges_versions,
)
from .models import Message, PromptFunction
from .summarize_nodes import Prompt as SummarizeNodesPrompt
from .summarize_nodes import Versions as SummarizeNodesVersions
from .summarize_nodes import versions as summarize_nodes_versions
class PromptLibrary(Protocol):
@ -80,6 +83,7 @@ class PromptLibrary(Protocol):
dedupe_edges: DedupeEdgesPrompt
invalidate_edges: InvalidateEdgesPrompt
extract_edge_dates: ExtractEdgeDatesPrompt
summarize_nodes: SummarizeNodesPrompt
class PromptLibraryImpl(TypedDict):
@ -89,6 +93,7 @@ class PromptLibraryImpl(TypedDict):
dedupe_edges: DedupeEdgesVersions
invalidate_edges: InvalidateEdgesVersions
extract_edge_dates: ExtractEdgeDatesVersions
summarize_nodes: SummarizeNodesVersions
class VersionWrapper:
@ -118,5 +123,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
'dedupe_edges': dedupe_edges_versions,
'invalidate_edges': invalidate_edges_versions,
'extract_edge_dates': extract_edge_dates_versions,
'summarize_nodes': summarize_nodes_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]

View File

@ -0,0 +1,79 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from typing import Any, Protocol, TypedDict
from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
    """Structural type for the node-summarization prompts, one ``PromptVersion`` per prompt."""

    # Prompt for merging a pair of summaries.
    summarize_pair: PromptVersion
    # Prompt for describing a summary in one sentence.
    summary_description: PromptVersion
class Versions(TypedDict):
    """Mapping from prompt name to the function that builds that prompt's messages."""

    summarize_pair: PromptFunction
    summary_description: PromptFunction
def summarize_pair(context: dict[str, Any]) -> list[Message]:
    """Build the system and user messages asking the model to merge summaries.

    Args:
        context: Must contain ``node_summaries``; it is JSON-encoded into the
            user message.

    Returns:
        A two-element list: the system message followed by the user message.
    """
    system_message = Message(
        role='system',
        content='You are a helpful assistant that combines summaries.',
    )
    user_message = Message(
        role='user',
        content=f"""
        Synthesize the information from the following two summaries into a single succinct summary.
        Summaries:
        {json.dumps(context['node_summaries'], indent=2)}
        Respond with a JSON object in the following format:
        {{
        "summary": "Summary containing the important information from both summaries"
        }}
        """,
    )
    return [system_message, user_message]
def summary_description(context: dict[str, Any]) -> list[Message]:
    """Build the system and user messages asking for a one-sentence description.

    Args:
        context: Must contain ``summary``; it is JSON-encoded into the user
            message.

    Returns:
        A two-element list: the system message followed by the user message.
    """
    system_message = Message(
        role='system',
        content='You are a helpful assistant that describes provided contents in a single sentence.',
    )
    user_message = Message(
        role='user',
        content=f"""
        Create a short one sentence description of the summary that explains what kind of information is summarized.
        Summary:
        {json.dumps(context['summary'], indent=2)}
        Respond with a JSON object in the following format:
        {{
        "description": "One sentence description of the provided summary"
        }}
        """,
    )
    return [system_message, user_message]
# Registry of this module's prompt builders, keyed by prompt name.
versions: Versions = {'summarize_pair': summarize_pair, 'summary_description': summary_description}

View File

@ -0,0 +1,155 @@
import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from neo4j import AsyncDriver
from graphiti_core.edges import CommunityEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.prompts import prompt_library
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
logger = logging.getLogger(__name__)
async def build_community_projection(driver: AsyncDriver) -> str:
    """Create the GDS graph projection named ``communities`` and return its name.

    The projection covers ``Entity`` nodes and ``RELATES_TO`` relationships,
    treated as undirected, with a ``weight`` property aggregated by COUNT.

    Returns:
        The ``graph`` value of the first row yielded by the projection call.
    """
    rows, _, _ = await driver.execute_query("""
    CALL gds.graph.project("communities", "Entity",
        {RELATES_TO: {
            type: "RELATES_TO",
            orientation: "UNDIRECTED",
            properties: {weight: {property: "*", aggregation: "COUNT"}}
        }}
    )
    YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
    """)

    first_row = rows[0]
    return first_row['graph']
async def destroy_projection(driver: AsyncDriver, projection_name: str):
    """Drop the GDS graph projection called ``projection_name``."""
    drop_query = """
    CALL gds.graph.drop($projection_name)
    """
    await driver.execute_query(drop_query, projection_name=projection_name)
async def get_community_clusters(
    driver: AsyncDriver, projection_name: str
) -> list[list[EntityNode]]:
    """Run Leiden on a GDS projection and return its communities as entity lists.

    Args:
        driver: Driver used to run the queries.
        projection_name: Name of the GDS graph projection to cluster.

    Returns:
        One list of ``EntityNode`` per community id reported by the algorithm.
    """
    # Use the caller's projection name rather than a hard-coded graph name.
    records, _, _ = await driver.execute_query(
        """
        CALL gds.leiden.stream($projection_name)
        YIELD nodeId, communityId
        RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
        """,
        projection_name=projection_name,
    )

    # Group member uuids by the community id assigned to them.
    community_map: dict[int, list[str]] = defaultdict(list)
    for record in records:
        community_map[record['communityId']].append(record['entity_uuid'])

    # Load the member nodes of every community concurrently.
    community_clusters: list[list[EntityNode]] = list(
        await asyncio.gather(
            *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
        )
    )

    return community_clusters
async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
    """Ask the LLM to merge two summaries into one.

    Returns:
        The ``summary`` field of the LLM response, or ``''`` when it is absent.
    """
    # Prepare context for LLM
    node_summaries = [{'summary': summary} for summary in summary_pair]
    messages = prompt_library.summarize_nodes.summarize_pair({'node_summaries': node_summaries})

    response = await llm_client.generate_response(messages)

    return response.get('summary', '')
async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
    """Ask the LLM for a one-sentence description of ``summary``.

    Returns:
        The ``description`` field of the LLM response, or ``''`` when it is absent.
    """
    messages = prompt_library.summarize_nodes.summary_description({'summary': summary})

    response = await llm_client.generate_response(messages)

    return response.get('description', '')
async def build_community(
    llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
    """Create the community node and membership edges for one cluster.

    Member summaries are merged pairwise, round after round, until a single
    summary remains; that summary is then described in one sentence to
    produce the community's name.

    Raises:
        IndexError: if ``community_cluster`` is empty.
    """
    pending = [entity.summary for entity in community_cluster]

    while len(pending) > 1:
        # With an odd count, hold the last summary back and carry it into the next round.
        carried: str | None = pending.pop() if len(pending) % 2 == 1 else None

        # Pair the first half with the second half and merge each pair concurrently.
        half = len(pending) // 2
        pair_jobs = [
            summarize_pair(llm_client, (str(left), str(right)))
            for left, right in zip(pending[:half], pending[half:])
        ]
        merged: list[str] = list(await asyncio.gather(*pair_jobs))

        if carried is not None:
            merged.append(carried)

        pending = merged

    summary = pending[0]
    name = await generate_summary_description(llm_client, summary)

    # The same timestamp is used for the node and for all of its edges.
    now = datetime.now()
    community_node = CommunityNode(
        name=name,
        group_id=community_cluster[0].group_id,
        labels=['Community'],
        created_at=now,
        summary=summary,
    )
    community_edges = build_community_edges(community_cluster, community_node, now)

    logger.info((community_node, community_edges))

    return community_node, community_edges
async def build_communities(
    driver: AsyncDriver, llm_client: LLMClient
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
    """Cluster the entity graph and build a community for every cluster.

    Args:
        driver: Driver used for the projection and clustering queries.
        llm_client: Client used to summarize and name each community.

    Returns:
        All community nodes, and all community edges flattened into one list.
        Nothing is saved here; persisting is left to the caller.
    """
    projection = await build_community_projection(driver)

    try:
        community_clusters = await get_community_clusters(driver, projection)

        communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
            await asyncio.gather(
                *[build_community(llm_client, cluster) for cluster in community_clusters]
            )
        )
    finally:
        # Always drop the projection, even on failure, so a leftover named
        # graph cannot block the next projection call.
        await destroy_projection(driver, projection)

    community_nodes: list[CommunityNode] = [node for node, _ in communities]
    community_edges: list[CommunityEdge] = [
        edge for _, edges in communities for edge in edges
    ]

    return community_nodes, community_edges
async def remove_communities(driver: AsyncDriver):
    """Detach-delete every node labelled ``Community``."""
    purge_query = """
    MATCH (c:Community)
    DETACH DELETE c
    """
    await driver.execute_query(purge_query)

View File

@ -20,9 +20,9 @@ from datetime import datetime
from time import time
from typing import List
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
@ -50,6 +50,24 @@ def build_episodic_edges(
return edges
def build_community_edges(
    entity_nodes: List[EntityNode],
    community_node: CommunityNode,
    created_at: datetime,
) -> List[CommunityEdge]:
    """Create one community-to-member edge for each node in ``entity_nodes``.

    Every edge has ``community_node`` as its source, takes the community's
    ``group_id``, and is stamped with ``created_at``.
    """
    membership: List[CommunityEdge] = []
    for member in entity_nodes:
        membership.append(
            CommunityEdge(
                source_node_uuid=community_node.uuid,
                target_node_uuid=member.uuid,
                created_at=created_at,
                group_id=community_node.group_id,
            )
        )

    return membership
async def extract_edges(
llm_client: LLMClient,
episode: EpisodicNode,

View File

@ -32,8 +32,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
range_indices: list[LiteralString] = [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
@ -51,6 +53,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
fulltext_indices: list[LiteralString] = [
'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
]
@ -71,6 +74,14 @@ async def build_indices_and_constraints(driver: AsyncDriver):
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
FOR (n:Community) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
]
index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices

View File

@ -73,6 +73,7 @@ def format_context(facts):
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities()
edges = await graphiti.search('Freakenomics guest', group_ids=['1'])