2024-08-26 13:11:50 -04:00
|
|
|
"""
|
|
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
"""
|
|
|
|
|
2024-08-22 13:06:42 -07:00
|
|
|
import asyncio
|
2024-08-18 13:22:31 -04:00
|
|
|
import logging
|
2024-08-15 11:04:57 -04:00
|
|
|
import os
|
2024-08-22 12:26:13 -07:00
|
|
|
import sys
|
|
|
|
from datetime import datetime
|
2024-08-15 11:04:57 -04:00
|
|
|
|
2024-08-22 13:06:42 -07:00
|
|
|
import pytest
|
2024-08-22 12:26:13 -07:00
|
|
|
from dotenv import load_dotenv
|
2024-08-15 11:04:57 -04:00
|
|
|
|
2024-08-25 10:07:50 -07:00
|
|
|
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
|
|
|
from graphiti_core.graphiti import Graphiti
|
|
|
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
2024-09-16 14:03:05 -04:00
|
|
|
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
|
2024-08-15 11:04:57 -04:00
|
|
|
|
2024-08-22 12:26:13 -07:00
|
|
|
# Mark every test in this module as an integration test.
pytestmark = pytest.mark.integration

# Register the pytest-asyncio plugin so the `async def` tests below can run.
pytest_plugins = ('pytest_asyncio',)

# Load variables from a .env file (if any) before the environment is read.
load_dotenv()

# Neo4j connection settings; each is None when the variable is not set.
# NOTE(review): the lowercase "j" in NEO4j_USER / NEO4j_PASSWORD is inconsistent
# with NEO4J_URI, but the tests in this file reference these exact names.
NEO4J_URI = os.environ.get('NEO4J_URI')
NEO4j_USER = os.environ.get('NEO4J_USER')
NEO4j_PASSWORD = os.environ.get('NEO4J_PASSWORD')
|
2024-08-15 11:04:57 -04:00
|
|
|
|
|
|
|
|
2024-08-18 13:22:31 -04:00
|
|
|
def setup_logging():
    """Configure the root logger to write INFO-level records to stdout.

    The function is idempotent: the console handler it installs is tagged with
    a fixed name, and any handler carrying that name from an earlier call is
    removed before a new one is attached. Calling it from several tests in the
    same process therefore does not duplicate every log line.

    Returns:
        The root logger, with its level set to INFO.
    """
    handler_name = 'graphiti_test_console'

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Drop the handler left by a previous call. Iterate over a copy because
    # removeHandler mutates logger.handlers.
    for existing in list(logger.handlers):
        if existing.get_name() == handler_name:
            logger.removeHandler(existing)

    # Bind to the sys.stdout that is current at call time, since test runners
    # may swap it between calls.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.set_name(handler_name)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(console_handler)

    return logger
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
2024-08-22 14:26:26 -04:00
|
|
|
def format_context(facts) -> str:
    """Render an iterable of facts as a bulleted block under a ``FACTS:`` header.

    Args:
        facts: Iterable of values; each is interpolated into its own
            `` - <fact>`` line.

    Returns:
        The header followed by one line per fact, joined with newlines and
        with leading/trailing whitespace stripped. With no facts the result
        is just ``'FACTS:'``.
    """
    lines = ['FACTS:']
    lines.extend(f' - {fact}' for fact in facts)
    # strip() keeps the original contract of removing any trailing whitespace,
    # including whitespace at the end of the last fact.
    return '\n'.join(lines).strip()
|
2024-08-18 13:22:31 -04:00
|
|
|
|
|
|
|
|
2024-08-15 11:04:57 -04:00
|
|
|
@pytest.mark.asyncio
async def test_graphiti_init():
    """Exercise community building and the search entry points end to end.

    The test makes no assertions on the results; it passes when every call
    completes without raising, and it logs what each search returned.

    NOTE(review): the first search uses a hard-coded ``center_node_uuid``, so
    this presumably expects a pre-populated graph -- confirm against the
    environment the integration tests run in.
    """
    logger = setup_logging()
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)

    try:
        await graphiti.build_communities()

        # Search re-ranked around a specific center node.
        edges = await graphiti.search(
            'tania tetlow', center_node_uuid='4bf7ebb3-3a98-46c7-90a6-8e516c487961', group_ids=None
        )
        logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges]))

        # Plain search without a center node.
        edges = await graphiti.search('issues with higher ed', group_ids=None)
        logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges]))

        # Lower-level search with an explicit search-config recipe; the result
        # exposes edges, nodes and communities.
        results = await graphiti._search(
            'issues with higher ed', COMBINED_HYBRID_SEARCH_RRF, group_ids=None
        )
        pretty_results = {
            'edges': [edge.fact for edge in results.edges],
            'nodes': [node.name for node in results.nodes],
            'communities': [community.name for community in results.communities],
        }
        logger.info(pretty_results)
    finally:
        # Release the client even when one of the calls above raises.
        graphiti.close()
|
2024-08-15 11:04:57 -04:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_integration():
    """Round-trip a small graph through save, get-by-uuid and delete.

    Builds one episode, two entity nodes ("Alice", "Bob"), two episodic edges
    linking the episode to each entity, and one entity edge ("likes"). Saves
    them all, asserts that one object of each type can be fetched back by
    uuid, then deletes everything.

    Deletion runs in a ``finally`` block so that a failed save or lookup does
    not leave test data behind, and the client is closed at the end the same
    way ``test_graphiti_init`` closes its client.
    """
    client = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)

    try:
        embedder = client.llm_client.get_embedder()
        driver = client.driver

        now = datetime.now()
        episode = EpisodicNode(
            name='test_episode',
            labels=[],
            created_at=now,
            valid_at=now,
            source='message',
            source_description='conversation message',
            content='Alice likes Bob',
            entity_edges=[],
        )

        alice_node = EntityNode(
            name='Alice',
            labels=[],
            created_at=now,
            summary='Alice summary',
        )

        bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')

        episodic_edge_1 = EpisodicEdge(
            source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
        )

        episodic_edge_2 = EpisodicEdge(
            source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
        )

        entity_edge = EntityEdge(
            source_node_uuid=alice_node.uuid,
            target_node_uuid=bob_node.uuid,
            created_at=now,
            name='likes',
            fact='Alice likes Bob',
            episodes=[],
            expired_at=now,
            valid_at=now,
            invalid_at=now,
        )

        await entity_edge.generate_embedding(embedder)

        nodes = [episode, alice_node, bob_node]
        edges = [episodic_edge_1, episodic_edge_2, entity_edge]

        try:
            # test save
            await asyncio.gather(*[node.save(driver) for node in nodes])
            await asyncio.gather(*[edge.save(driver) for edge in edges])

            # test get
            assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
            assert await EntityNode.get_by_uuid(driver, alice_node.uuid) is not None
            assert await EpisodicEdge.get_by_uuid(driver, episodic_edge_1.uuid) is not None
            assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
        finally:
            # test delete -- also serves as cleanup when a step above fails
            await asyncio.gather(*[node.delete(driver) for node in nodes])
            await asyncio.gather(*[edge.delete(driver) for edge in edges])
    finally:
        # Release the client whether or not the test body succeeded.
        client.close()
|