graphiti/tests/test_node_int.py
Preston Rasmussen 1460172568
don't return index labels (#887)
* don't return index labels

* update tests
2025-09-02 12:02:33 -04:00

230 lines
7.4 KiB
Python

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime, timedelta
from uuid import uuid4
import pytest
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
from tests.helpers_test import (
assert_community_node_equals,
assert_entity_node_equals,
assert_episodic_node_equals,
get_node_count,
group_id,
)
# Reference timestamps, all offset from a single naive local-time "now" so
# their relative ordering is fixed: created < valid < invalid < deleted.
# Only created_at and valid_at are consumed by the fixtures in this module.
created_at = datetime.now()
valid_at = created_at + timedelta(days=1)
invalid_at = created_at + timedelta(days=2)
deleted_at = created_at + timedelta(days=3)
@pytest.fixture
def sample_entity_node():
    """Return an unsaved ``EntityNode`` in the shared test group.

    The node gets a fresh uuid, the labels ``Entity`` and ``Person``, a
    constant 1024-element name embedding and two custom attributes.
    """
    embedding = [0.5] * 1024
    extra_attributes = {'age': 30, 'location': 'New York'}
    node = EntityNode(
        uuid=str(uuid4()),
        name='Test Entity',
        group_id=group_id,
        labels=['Entity', 'Person'],
        created_at=created_at,
        name_embedding=embedding,
        summary='Entity Summary',
        attributes=extra_attributes,
    )
    return node
@pytest.fixture
def sample_episodic_node():
    """Return an unsaved text ``EpisodicNode`` in the shared test group.

    The episode has a fresh uuid, is valid from ``valid_at`` and references
    no entity edges.
    """
    fresh_uuid = str(uuid4())
    episode = EpisodicNode(
        uuid=fresh_uuid,
        name='Episode 1',
        group_id=group_id,
        created_at=created_at,
        valid_at=valid_at,
        source=EpisodeType.text,
        source_description='Test source',
        content='Some content here',
        entity_edges=[],
    )
    return episode
@pytest.fixture
def sample_community_node():
    """Return an unsaved ``CommunityNode`` in the shared test group.

    The node gets a fresh uuid and a constant 1024-element name embedding.
    """
    embedding = [0.5] * 1024
    community = CommunityNode(
        uuid=str(uuid4()),
        name='Community A',
        summary='Community summary',
        group_id=group_id,
        created_at=created_at,
        name_embedding=embedding,
    )
    return community
@pytest.mark.asyncio
async def test_entity_node(sample_entity_node, graph_driver):
    """Round-trip an ``EntityNode`` through save, the three getters and the three deleters.

    The driver is closed in a ``finally`` block so that a failing assertion
    does not leak the connection (previously ``close()`` only ran on success).
    """
    uuid = sample_entity_node.uuid

    async def assert_count(expected: int) -> None:
        # Count only this node's uuid so unrelated graph data cannot interfere.
        assert await get_node_count(graph_driver, [uuid]) == expected

    try:
        # Create node
        await assert_count(0)
        await sample_entity_node.save(graph_driver)
        await assert_count(1)

        # Get node by uuid
        retrieved = await EntityNode.get_by_uuid(graph_driver, uuid)
        await assert_entity_node_equals(graph_driver, retrieved, sample_entity_node)

        # Get node by uuids
        retrieved = await EntityNode.get_by_uuids(graph_driver, [uuid])
        await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)

        # Get node by group ids; limit=2 means a stray second node in the
        # group would surface as len == 2 and fail the check below.
        retrieved = await EntityNode.get_by_group_ids(
            graph_driver, [group_id], limit=2, with_embeddings=True
        )
        assert len(retrieved) == 1
        await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)

        # Delete node by uuid
        await sample_entity_node.delete(graph_driver)
        await assert_count(0)

        # Delete node by uuids (re-save first so there is something to delete)
        await sample_entity_node.save(graph_driver)
        await assert_count(1)
        await sample_entity_node.delete_by_uuids(graph_driver, [uuid])
        await assert_count(0)

        # Delete node by group id
        await sample_entity_node.save(graph_driver)
        await assert_count(1)
        await sample_entity_node.delete_by_group_id(graph_driver, group_id)
        await assert_count(0)
    finally:
        await graph_driver.close()
@pytest.mark.asyncio
async def test_community_node(sample_community_node, graph_driver):
    """Round-trip a ``CommunityNode`` through save, the three getters and the three deleters.

    The driver is closed in a ``finally`` block so that a failing assertion
    does not leak the connection (previously ``close()`` only ran on success).
    """
    uuid = sample_community_node.uuid

    async def assert_count(expected: int) -> None:
        # Count only this node's uuid so unrelated graph data cannot interfere.
        assert await get_node_count(graph_driver, [uuid]) == expected

    try:
        # Create node
        await assert_count(0)
        await sample_community_node.save(graph_driver)
        await assert_count(1)

        # Get node by uuid
        retrieved = await CommunityNode.get_by_uuid(graph_driver, uuid)
        await assert_community_node_equals(graph_driver, retrieved, sample_community_node)

        # Get node by uuids
        retrieved = await CommunityNode.get_by_uuids(graph_driver, [uuid])
        await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)

        # Get node by group ids; limit=2 means a stray second node in the
        # group would surface as len == 2 and fail the check below.
        retrieved = await CommunityNode.get_by_group_ids(graph_driver, [group_id], limit=2)
        assert len(retrieved) == 1
        await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)

        # Delete node by uuid
        await sample_community_node.delete(graph_driver)
        await assert_count(0)

        # Delete node by uuids (re-save first so there is something to delete)
        await sample_community_node.save(graph_driver)
        await assert_count(1)
        await sample_community_node.delete_by_uuids(graph_driver, [uuid])
        await assert_count(0)

        # Delete node by group id
        await sample_community_node.save(graph_driver)
        await assert_count(1)
        await sample_community_node.delete_by_group_id(graph_driver, group_id)
        await assert_count(0)
    finally:
        await graph_driver.close()
@pytest.mark.asyncio
async def test_episodic_node(sample_episodic_node, graph_driver):
    """Round-trip an ``EpisodicNode`` through save, the three getters and the three deleters.

    The driver is closed in a ``finally`` block so that a failing assertion
    does not leak the connection (previously ``close()`` only ran on success).
    """
    uuid = sample_episodic_node.uuid

    async def assert_count(expected: int) -> None:
        # Count only this node's uuid so unrelated graph data cannot interfere.
        assert await get_node_count(graph_driver, [uuid]) == expected

    try:
        # Create node
        await assert_count(0)
        await sample_episodic_node.save(graph_driver)
        await assert_count(1)

        # Get node by uuid
        retrieved = await EpisodicNode.get_by_uuid(graph_driver, uuid)
        await assert_episodic_node_equals(retrieved, sample_episodic_node)

        # Get node by uuids
        retrieved = await EpisodicNode.get_by_uuids(graph_driver, [uuid])
        await assert_episodic_node_equals(retrieved[0], sample_episodic_node)

        # Get node by group ids; limit=2 means a stray second node in the
        # group would surface as len == 2 and fail the check below.
        retrieved = await EpisodicNode.get_by_group_ids(graph_driver, [group_id], limit=2)
        assert len(retrieved) == 1
        await assert_episodic_node_equals(retrieved[0], sample_episodic_node)

        # Delete node by uuid
        await sample_episodic_node.delete(graph_driver)
        await assert_count(0)

        # Delete node by uuids (re-save first so there is something to delete)
        await sample_episodic_node.save(graph_driver)
        await assert_count(1)
        await sample_episodic_node.delete_by_uuids(graph_driver, [uuid])
        await assert_count(0)

        # Delete node by group id
        await sample_episodic_node.save(graph_driver)
        await assert_count(1)
        await sample_episodic_node.delete_by_group_id(graph_driver, group_id)
        await assert_count(0)
    finally:
        await graph_driver.close()