2025-04-20 23:20:19 -04:00
|
|
|
"""
|
|
|
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
|
2025-07-29 06:07:34 -07:00
|
|
|
from datetime import datetime
|
2025-04-20 23:20:19 -04:00
|
|
|
from uuid import uuid4
|
|
|
|
|
|
2025-07-29 06:07:34 -07:00
|
|
|
import numpy as np
|
2025-04-20 23:20:19 -04:00
|
|
|
import pytest
|
|
|
|
|
|
2025-07-29 06:07:34 -07:00
|
|
|
from graphiti_core.driver.driver import GraphDriver
|
2025-04-20 23:20:19 -04:00
|
|
|
from graphiti_core.nodes import (
|
|
|
|
|
CommunityNode,
|
|
|
|
|
EntityNode,
|
|
|
|
|
EpisodeType,
|
|
|
|
|
EpisodicNode,
|
|
|
|
|
)
|
2025-07-29 06:07:34 -07:00
|
|
|
from tests.helpers_test import drivers, get_driver
|
2025-04-20 23:20:19 -04:00
|
|
|
|
2025-07-29 06:07:34 -07:00
|
|
|
# Unique per test-module run so repeated runs never see each other's nodes.
group_id = 'test_group_' + str(uuid4())
|
2025-04-20 23:20:19 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_entity_node():
    """Build a fresh, not-yet-persisted ``EntityNode`` in the module's test group.

    The node carries a constant 1024-element name embedding (every value 0.5)
    and no extra labels.
    """
    embedding = [0.5] * 1024
    node = EntityNode(
        uuid=str(uuid4()),
        name='Test Entity',
        group_id=group_id,
        labels=[],
        name_embedding=embedding,
        summary='Entity Summary',
    )
    return node
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_episodic_node():
    """Build a fresh, not-yet-persisted text ``EpisodicNode`` in the module's test group.

    ``valid_at`` is the (naive) local time at fixture creation.
    """
    created_at = datetime.now()
    episode = EpisodicNode(
        uuid=str(uuid4()),
        name='Episode 1',
        group_id=group_id,
        source=EpisodeType.text,
        source_description='Test source',
        content='Some content here',
        valid_at=created_at,
    )
    return episode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_community_node():
    """Build a fresh, not-yet-persisted ``CommunityNode`` in the module's test group.

    Uses the same constant 1024-element name embedding (all 0.5) as the entity fixture.
    """
    embedding = [0.5] * 1024
    community = CommunityNode(
        uuid=str(uuid4()),
        name='Community A',
        name_embedding=embedding,
        group_id=group_id,
        summary='Community summary',
    )
    return community
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_entity_node(sample_entity_node, driver):
    """Round-trip an ``EntityNode``: save, fetch (3 ways), load embedding, delete (2 ways).

    Runs once per entry in ``drivers``. The driver is closed in a ``finally``
    block so a failing assertion does not leak the connection.
    """
    driver = get_driver(driver)
    uuid = sample_entity_node.uuid

    try:
        # Create node
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0
        await sample_entity_node.save(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 1

        retrieved = await EntityNode.get_by_uuid(driver, sample_entity_node.uuid)
        assert retrieved.uuid == sample_entity_node.uuid
        assert retrieved.name == 'Test Entity'
        assert retrieved.group_id == group_id

        retrieved = await EntityNode.get_by_uuids(driver, [sample_entity_node.uuid])
        # Guard the indexing below: exactly one node was requested and saved.
        assert len(retrieved) == 1
        assert retrieved[0].uuid == sample_entity_node.uuid
        assert retrieved[0].name == 'Test Entity'
        assert retrieved[0].group_id == group_id

        retrieved = await EntityNode.get_by_group_ids(driver, [group_id], limit=2)
        assert len(retrieved) == 1
        assert retrieved[0].uuid == sample_entity_node.uuid
        assert retrieved[0].name == 'Test Entity'
        assert retrieved[0].group_id == group_id

        # Clear the in-memory embedding first; otherwise the fixture's value would
        # make the assertion pass even if load_name_embedding loaded nothing.
        sample_entity_node.name_embedding = None
        await sample_entity_node.load_name_embedding(driver)
        assert np.allclose(sample_entity_node.name_embedding, [0.5] * 1024)

        # Delete node by uuid
        await sample_entity_node.delete(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0

        # Delete node by group id
        await sample_entity_node.save(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 1
        await sample_entity_node.delete_by_group_id(driver, group_id)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0
    finally:
        await driver.close()
|
2025-04-20 23:20:19 -04:00
|
|
|
|
|
|
|
|
|
2025-07-29 06:07:34 -07:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_community_node(sample_community_node, driver):
    """Round-trip a ``CommunityNode``: save, fetch (3 ways), delete (2 ways).

    Runs once per entry in ``drivers``. The driver is closed in a ``finally``
    block so a failing assertion does not leak the connection.
    """
    driver = get_driver(driver)
    uuid = sample_community_node.uuid

    try:
        # Create node
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0
        await sample_community_node.save(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 1

        retrieved = await CommunityNode.get_by_uuid(driver, sample_community_node.uuid)
        assert retrieved.uuid == sample_community_node.uuid
        assert retrieved.name == 'Community A'
        assert retrieved.group_id == group_id
        assert retrieved.summary == 'Community summary'

        retrieved = await CommunityNode.get_by_uuids(driver, [sample_community_node.uuid])
        # Guard the indexing below: exactly one node was requested and saved.
        assert len(retrieved) == 1
        assert retrieved[0].uuid == sample_community_node.uuid
        assert retrieved[0].name == 'Community A'
        assert retrieved[0].group_id == group_id
        assert retrieved[0].summary == 'Community summary'

        retrieved = await CommunityNode.get_by_group_ids(driver, [group_id], limit=2)
        assert len(retrieved) == 1
        assert retrieved[0].uuid == sample_community_node.uuid
        assert retrieved[0].name == 'Community A'
        assert retrieved[0].group_id == group_id

        # Delete node by uuid
        await sample_community_node.delete(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0

        # Delete node by group id
        await sample_community_node.save(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 1
        await sample_community_node.delete_by_group_id(driver, group_id)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0
    finally:
        await driver.close()
|
2025-04-20 23:20:19 -04:00
|
|
|
|
2025-07-29 06:07:34 -07:00
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
    ids=drivers,
)
async def test_episodic_node(sample_episodic_node, driver):
    """Round-trip an ``EpisodicNode``: save, fetch (3 ways), delete (2 ways).

    Runs once per entry in ``drivers``. The driver is closed in a ``finally``
    block so a failing assertion does not leak the connection.
    """
    driver = get_driver(driver)
    uuid = sample_episodic_node.uuid

    try:
        # Create node
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0
        await sample_episodic_node.save(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 1

        retrieved = await EpisodicNode.get_by_uuid(driver, sample_episodic_node.uuid)
        assert retrieved.uuid == sample_episodic_node.uuid
        assert retrieved.name == 'Episode 1'
        assert retrieved.group_id == group_id
        assert retrieved.source == EpisodeType.text
        assert retrieved.source_description == 'Test source'
        assert retrieved.content == 'Some content here'
        assert retrieved.valid_at == sample_episodic_node.valid_at

        retrieved = await EpisodicNode.get_by_uuids(driver, [sample_episodic_node.uuid])
        # Guard the indexing below: exactly one node was requested and saved.
        assert len(retrieved) == 1
        assert retrieved[0].uuid == sample_episodic_node.uuid
        assert retrieved[0].name == 'Episode 1'
        assert retrieved[0].group_id == group_id
        assert retrieved[0].source == EpisodeType.text
        assert retrieved[0].source_description == 'Test source'
        assert retrieved[0].content == 'Some content here'
        assert retrieved[0].valid_at == sample_episodic_node.valid_at

        retrieved = await EpisodicNode.get_by_group_ids(driver, [group_id], limit=2)
        assert len(retrieved) == 1
        assert retrieved[0].uuid == sample_episodic_node.uuid
        assert retrieved[0].name == 'Episode 1'
        assert retrieved[0].group_id == group_id
        assert retrieved[0].source == EpisodeType.text
        assert retrieved[0].source_description == 'Test source'
        assert retrieved[0].content == 'Some content here'
        assert retrieved[0].valid_at == sample_episodic_node.valid_at

        # Delete node by uuid
        await sample_episodic_node.delete(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0

        # Delete node by group id
        await sample_episodic_node.save(driver)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 1
        await sample_episodic_node.delete_by_group_id(driver, group_id)
        node_count = await get_node_count(driver, uuid)
        assert node_count == 0
    finally:
        await driver.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_node_count(driver: GraphDriver, uuid: str):
    """Return how many graph nodes (of any label) have ``uuid`` as their uuid property.

    ``driver.execute_query`` is unpacked as a 3-tuple; only its first element,
    the result rows, is used. The aggregate query yields a single row.
    """
    count_query = """
        MATCH (n {uuid: $uuid})
        RETURN COUNT(n) as count
        """
    records, _, _ = await driver.execute_query(count_query, uuid=uuid)
    only_row = records[0]
    return int(only_row['count'])
|