Node episodes list (#381)

* added episode list virtual field

* in progress tests

* add tests

* update search return type

* linter

* copyright notice

* mark integration tests
This commit is contained in:
Preston Rasmussen 2025-04-20 23:20:19 -04:00 committed by GitHub
parent 064d9207d2
commit 009467650f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 186 additions and 75 deletions

View File

@ -38,6 +38,20 @@ from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)
ENTITY_NODE_RETURN: LiteralString = """
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
WITH n, collect(e.uuid) AS episodes
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes,
episodes"""
class EpisodeType(Enum):
"""
@ -280,6 +294,9 @@ class EpisodicNode(Node):
class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
episodes: list[str] | None = Field(
default=None, description='List of episode uuids that mention this node.'
)
attributes: dict[str, Any] = Field(
default={}, description='Additional attributes of the node. Dependent on node labels'
)
@ -318,19 +335,14 @@ class EntityNode(Node):
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
records, _, _ = await driver.execute_query(
query = (
"""
MATCH (n:Entity {uuid: $uuid})
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
""",
MATCH (n:Entity {uuid: $uuid})
"""
+ ENTITY_NODE_RETURN
)
records, _, _ = await driver.execute_query(
query,
uuid=uuid,
database_=DEFAULT_DATABASE,
routing_='r',
@ -348,16 +360,8 @@ class EntityNode(Node):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity) WHERE n.uuid IN $uuids
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
""",
"""
+ ENTITY_NODE_RETURN,
uuids=uuids,
database_=DEFAULT_DATABASE,
routing_='r',
@ -383,16 +387,8 @@ class EntityNode(Node):
MATCH (n:Entity) WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ ENTITY_NODE_RETURN
+ """
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
ORDER BY n.uuid DESC
"""
+ limit_query,
@ -548,6 +544,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
created_at=record['created_at'].to_native(),
summary=record['summary'],
attributes=record['attributes'],
episodes=record['episodes'],
)
entity_node.attributes.pop('uuid', None)

View File

@ -32,6 +32,7 @@ from graphiti_core.helpers import (
semaphore_gather,
)
from graphiti_core.nodes import (
ENTITY_NODE_RETURN,
CommunityNode,
EntityNode,
EpisodicNode,
@ -53,6 +54,20 @@ DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 32
SEARCH_ENTITY_NODE_RETURN: LiteralString = """
OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n)
WITH n, score, collect(e.uuid) AS episodes
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes,
episodes"""
def fulltext_query(query: str, group_ids: list[str] | None = None):
group_ids_filter_list = (
@ -230,8 +245,8 @@ async def edge_similarity_search(
query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
@ -341,27 +356,21 @@ async def node_fulltext_search(
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS node, score
MATCH (n:Entity)
WHERE n.uuid = node.uuid
query = (
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
WHERE n:Entity
"""
+ filter_query
+ SEARCH_ENTITY_NODE_RETURN
+ """
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
ORDER BY score DESC
LIMIT $limit
""",
"""
)
records, _, _ = await driver.execute_query(
query,
filter_params,
query=fuzzy_query,
group_ids=group_ids,
@ -406,19 +415,12 @@ async def node_similarity_search(
+ filter_query
+ """
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score
RETURN
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
ORDER BY score DESC
LIMIT $limit
""",
WHERE score > $min_score"""
+ SEARCH_ENTITY_NODE_RETURN
+ """
ORDER BY score DESC
LIMIT $limit
""",
query_params,
search_vector=search_vector,
group_ids=group_ids,
@ -452,16 +454,8 @@ async def node_bfs_search(
WHERE n.group_id = origin.group_id
"""
+ filter_query
+ ENTITY_NODE_RETURN
+ """
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
filter_params,

View File

@ -65,9 +65,7 @@ async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
results = await graphiti.search_(
query='Who is the User?',
)
results = await graphiti.search_(query='Who is the User?')
pretty_results = search_results_to_context_string(results)

122
tests/test_node_int.py Normal file
View File

@ -0,0 +1,122 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from neo4j import AsyncGraphDatabase
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
@pytest.fixture
def sample_entity_node():
return EntityNode(
uuid=str(uuid4()),
name='Test Entity',
group_id='test_group',
labels=['Entity'],
name_embedding=[0.5] * 1024,
summary='Entity Summary',
)
@pytest.fixture
def sample_episodic_node():
return EpisodicNode(
uuid=str(uuid4()),
name='Episode 1',
group_id='test_group',
source=EpisodeType.text,
source_description='Test source',
content='Some content here',
valid_at=datetime.now(timezone.utc),
)
@pytest.fixture
def sample_community_node():
return CommunityNode(
uuid=str(uuid4()),
name='Community A',
name_embedding=[0.5] * 1024,
group_id='test_group',
summary='Community summary',
)
@pytest.mark.asyncio
@pytest.mark.integration
async def test_entity_node_save_get_and_delete(sample_entity_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
await sample_entity_node.save(neo4j_driver)
retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid)
assert retrieved.uuid == sample_entity_node.uuid
assert retrieved.name == 'Test Entity'
assert retrieved.group_id == 'test_group'
await sample_entity_node.delete(neo4j_driver)
await neo4j_driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_community_node_save_get_and_delete(sample_community_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
await sample_community_node.save(neo4j_driver)
retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid)
assert retrieved.uuid == sample_community_node.uuid
assert retrieved.name == 'Community A'
assert retrieved.group_id == 'test_group'
assert retrieved.summary == 'Community summary'
await sample_community_node.delete(neo4j_driver)
await neo4j_driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
await sample_episodic_node.save(neo4j_driver)
retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid)
assert retrieved.uuid == sample_episodic_node.uuid
assert retrieved.name == 'Episode 1'
assert retrieved.group_id == 'test_group'
assert retrieved.source == EpisodeType.text
assert retrieved.source_description == 'Test source'
assert retrieved.content == 'Some content here'
await sample_episodic_node.delete(neo4j_driver)
await neo4j_driver.close()