graphiti/tests/test_node_int.py

123 lines
3.6 KiB
Python
Raw Normal View History

"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from neo4j import AsyncGraphDatabase
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
@pytest.fixture
def sample_entity_node():
return EntityNode(
uuid=str(uuid4()),
name='Test Entity',
group_id='test_group',
labels=['Entity'],
name_embedding=[0.5] * 1024,
summary='Entity Summary',
)
@pytest.fixture
def sample_episodic_node():
return EpisodicNode(
uuid=str(uuid4()),
name='Episode 1',
group_id='test_group',
source=EpisodeType.text,
source_description='Test source',
content='Some content here',
valid_at=datetime.now(timezone.utc),
)
@pytest.fixture
def sample_community_node():
return CommunityNode(
uuid=str(uuid4()),
name='Community A',
name_embedding=[0.5] * 1024,
group_id='test_group',
summary='Community summary',
)
@pytest.mark.asyncio
@pytest.mark.integration
async def test_entity_node_save_get_and_delete(sample_entity_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
await sample_entity_node.save(neo4j_driver)
retrieved = await EntityNode.get_by_uuid(neo4j_driver, sample_entity_node.uuid)
assert retrieved.uuid == sample_entity_node.uuid
assert retrieved.name == 'Test Entity'
assert retrieved.group_id == 'test_group'
await sample_entity_node.delete(neo4j_driver)
await neo4j_driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_community_node_save_get_and_delete(sample_community_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
await sample_community_node.save(neo4j_driver)
retrieved = await CommunityNode.get_by_uuid(neo4j_driver, sample_community_node.uuid)
assert retrieved.uuid == sample_community_node.uuid
assert retrieved.name == 'Community A'
assert retrieved.group_id == 'test_group'
assert retrieved.summary == 'Community summary'
await sample_community_node.delete(neo4j_driver)
await neo4j_driver.close()
@pytest.mark.asyncio
@pytest.mark.integration
async def test_episodic_node_save_get_and_delete(sample_episodic_node):
neo4j_driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
await sample_episodic_node.save(neo4j_driver)
retrieved = await EpisodicNode.get_by_uuid(neo4j_driver, sample_episodic_node.uuid)
assert retrieved.uuid == sample_episodic_node.uuid
assert retrieved.name == 'Episode 1'
assert retrieved.group_id == 'test_group'
assert retrieved.source == EpisodeType.text
assert retrieved.source_description == 'Test source'
assert retrieved.content == 'Some content here'
await sample_episodic_node.delete(neo4j_driver)
await neo4j_driver.close()