diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 73b73b42..c20fdd5d 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -38,6 +38,20 @@ from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) +ENTITY_NODE_RETURN: LiteralString = """ + OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n) + WITH n, collect(e.uuid) AS episodes + RETURN + n.uuid As uuid, + n.name AS name, + n.name_embedding AS name_embedding, + n.group_id AS group_id, + n.created_at AS created_at, + n.summary AS summary, + labels(n) AS labels, + properties(n) AS attributes, + episodes""" + class EpisodeType(Enum): """ @@ -280,6 +294,9 @@ class EpisodicNode(Node): class EntityNode(Node): name_embedding: list[float] | None = Field(default=None, description='embedding of the name') summary: str = Field(description='regional summary of surrounding edges', default_factory=str) + episodes: list[str] | None = Field( + default=None, description='List of episode uuids that mention this node.' + ) attributes: dict[str, Any] = Field( default={}, description='Additional attributes of the node. Dependent on node labels' ) @@ -318,19 +335,14 @@ class EntityNode(Node): @classmethod async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): - records, _, _ = await driver.execute_query( + query = ( """ - MATCH (n:Entity {uuid: $uuid}) - RETURN - n.uuid As uuid, - n.name AS name, - n.name_embedding AS name_embedding, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes - """, + MATCH (n:Entity {uuid: $uuid}) + """ + + ENTITY_NODE_RETURN + ) + records, _, _ = await driver.execute_query( + query, uuid=uuid, database_=DEFAULT_DATABASE, routing_='r', @@ -348,16 +360,8 @@ class EntityNode(Node): records, _, _ = await driver.execute_query( """ MATCH (n:Entity) WHERE n.uuid IN $uuids - RETURN - n.uuid As uuid, - n.name AS name, - n.name_embedding AS name_embedding, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes - """, + """ + + ENTITY_NODE_RETURN, uuids=uuids, database_=DEFAULT_DATABASE, routing_='r', @@ -383,16 +387,8 @@ class EntityNode(Node): MATCH (n:Entity) WHERE n.group_id IN $group_ids """ + cursor_query + + ENTITY_NODE_RETURN + """ - RETURN - n.uuid As uuid, - n.name AS name, - n.name_embedding AS name_embedding, - n.group_id AS group_id, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes ORDER BY n.uuid DESC """ + limit_query, @@ -548,6 +544,7 @@ def get_entity_node_from_record(record: Any) -> EntityNode: created_at=record['created_at'].to_native(), summary=record['summary'], attributes=record['attributes'], + episodes=record['episodes'], ) entity_node.attributes.pop('uuid', None) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 86c21de5..4095a88b 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -32,6 +32,7 @@ from graphiti_core.helpers import ( semaphore_gather, ) from graphiti_core.nodes import ( + ENTITY_NODE_RETURN, CommunityNode, EntityNode, EpisodicNode, @@ -53,6 +54,20 @@ DEFAULT_MMR_LAMBDA = 0.5 MAX_SEARCH_DEPTH = 3 MAX_QUERY_LENGTH = 32 +SEARCH_ENTITY_NODE_RETURN: LiteralString = """ + OPTIONAL MATCH (e:Episodic)-[r:MENTIONS]->(n) + WITH n, score, collect(e.uuid) AS episodes + RETURN + n.uuid As uuid, + n.name AS name, + n.name_embedding AS name_embedding, + n.group_id AS group_id, + n.created_at AS created_at, + n.summary AS summary, + labels(n) AS labels, + properties(n) AS attributes, + episodes""" + def fulltext_query(query: str, group_ids: list[str] | None = None): group_ids_filter_list = ( @@ -230,8 +245,8 @@ async def edge_similarity_search( query: LiteralString = ( """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + group_filter_query + filter_query + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score @@ -341,27 +356,21 @@ async def node_fulltext_search( filter_query, filter_params = node_search_filter_query_constructor(search_filter) - records, _, _ = await driver.execute_query( - """ - CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS node, score - MATCH (n:Entity) - WHERE n.uuid = node.uuid + query = ( """ + CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) + YIELD node AS n, score + WHERE n:Entity + """ + filter_query + + SEARCH_ENTITY_NODE_RETURN + """ - RETURN - n.uuid AS uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes ORDER BY score DESC - LIMIT $limit - """, + """ + ) + + records, _, _ = await driver.execute_query( + query, filter_params, query=fuzzy_query, group_ids=group_ids, @@ -406,19 +415,12 @@ async def node_similarity_search( + filter_query + """ WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score - WHERE score > $min_score - RETURN - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes - ORDER BY score DESC - LIMIT $limit - """, + WHERE score > $min_score""" + + SEARCH_ENTITY_NODE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit + """, query_params, search_vector=search_vector, group_ids=group_ids, @@ -452,16 +454,8 @@ async def node_bfs_search( WHERE n.group_id = origin.group_id """ + filter_query + + ENTITY_NODE_RETURN + """ - RETURN DISTINCT - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes LIMIT $limit """, filter_params, diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index a2c7c378..3efd2697 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -65,9 +65,7 @@ async def test_graphiti_init(): logger = setup_logging() graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD) - results = await graphiti.search_( - query='Who is the User?', - ) + results = await graphiti.search_(query='Who is the User?') pretty_results = search_results_to_context_string(results) diff --git a/tests/test_node_int.py b/tests/test_node_int.py new file mode 100644 index 00000000..9f50f18a --- /dev/null +++ b/tests/test_node_int.py @@ -0,0 +1,122 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest +from neo4j import AsyncGraphDatabase + +from graphiti_core.nodes import ( + CommunityNode, + EntityNode, + EpisodeType, + EpisodicNode, +) + +NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687') +NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j') +NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test') + + +@pytest.fixture +def sample_entity_node(): + return EntityNode( + uuid=str(uuid4()), + name='Test Entity', + group_id='test_group', + labels=['Entity'], + name_embedding=[0.5] * 1024, + summary='Entity Summary', + ) + + +@pytest.fixture +def sample_episodic_node(): + return EpisodicNode( + uuid=str(uuid4()), + name='Episode 1', + group_id='test_group', + source=EpisodeType.text, + source_description='Test source', + content='Some content here', + valid_at=datetime.now(timezone.utc), + ) + + +@pytest.fixture +def sample_community_node(): + return CommunityNode( + uuid=str(uuid4()), + name='Community A', + name_embedding=[0.5] * 1024, + group_id='test_group', + summary='Community summary', + ) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_entity_node_save_get_and_delete(sample_entity_node): + neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) + await sample_entity_node.save(neo4j_driver) + retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid) + assert retrieved.uuid == sample_entity_node.uuid + assert retrieved.name == 'Test Entity' + assert retrieved.group_id == 'test_group' + + await sample_entity_node.delete(neo4j_driver) + + await neo4j_driver.close() + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_community_node_save_get_and_delete(sample_community_node): + neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) + + await sample_community_node.save(neo4j_driver) + + retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid) + assert retrieved.uuid == sample_community_node.uuid + assert retrieved.name == 'Community A' + assert retrieved.group_id == 'test_group' + assert retrieved.summary == 'Community summary' + + await sample_community_node.delete(neo4j_driver) + + await neo4j_driver.close() + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_episodic_node_save_get_and_delete(sample_episodic_node): + neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) + + await sample_episodic_node.save(neo4j_driver) + + retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid) + assert retrieved.uuid == sample_episodic_node.uuid + assert retrieved.name == 'Episode 1' + assert retrieved.group_id == 'test_group' + assert retrieved.source == EpisodeType.text + assert retrieved.source_description == 'Test source' + assert retrieved.content == 'Some content here' + + await sample_episodic_node.delete(neo4j_driver) + + await neo4j_driver.close()