graphiti/tests/tests_int_graphiti.py
Daniel Chalef c5e52153c4
chore: Fix packaging (#38)
* feat: Update project name and description

The project name and description in the `pyproject.toml` file have been updated to reflect the changes made to the project.

* chore: Update pyproject.toml to include core package

The `pyproject.toml` file has been updated to include the `core` package in the list of packages. This change ensures that the `core` package is included when building the project.

* fix imports

* fix imports
2024-08-25 10:07:50 -07:00

132 lines
3.4 KiB
Python

import asyncio
import logging
import os
import sys
from datetime import datetime
import pytest
from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase
from openai import OpenAI
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EntityNode, EpisodicNode
# Apply the `integration` marker to every test defined in this module.
pytestmark = pytest.mark.integration
# Ask pytest to load the pytest-asyncio plugin, which the `@pytest.mark.asyncio` tests below rely on.
pytest_plugins = ('pytest_asyncio',)
# Load environment variables via python-dotenv before the settings below are read.
load_dotenv()
# Neo4j connection settings taken from the environment; os.getenv yields None for any unset variable.
# NOTE(review): the last two names use a mixed-case `NEO4j_` prefix; the tests in this file
# reference them with that exact spelling, so renaming must be done everywhere at once.
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4j_USER = os.getenv('NEO4J_USER')
NEO4j_PASSWORD = os.getenv('NEO4J_PASSWORD')
def setup_logging():
    """Configure the root logger to write INFO-level records to stdout.

    The console handler is tagged with a fixed name so that repeated calls
    replace the previously installed handler instead of stacking a new one on
    top of it. Without this, every call would add another handler and each log
    record would be printed once per call made so far.

    Returns:
        The root logger, with its level set to INFO.
    """
    handler_name = 'graphiti_tests_console'

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Drop the handler left behind by an earlier call. Iterate over a copy
    # because removeHandler mutates logger.handlers.
    for existing in list(logger.handlers):
        if existing.get_name() == handler_name:
            logger.removeHandler(existing)

    # A fresh handler is created on every call so that it is bound to whatever
    # sys.stdout currently is (it may have been swapped since the last call).
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.set_name(handler_name)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(console_handler)

    return logger
def format_context(facts):
    """Render facts as a 'FACTS:' header followed by one bullet line per fact.

    Args:
        facts: Any iterable of objects; each is interpolated with ``str()``
            formatting into its own `` - <fact>`` line.

    Returns:
        The formatted text with leading and trailing whitespace stripped.
        An empty iterable yields just ``'FACTS:'``.
    """
    lines = ['FACTS:']
    lines.extend(f' - {fact}' for fact in facts)
    # The final strip() also trims any trailing whitespace carried by the last
    # fact, exactly as stripping the concatenated string did before.
    return '\n'.join(lines).strip()
@pytest.mark.asyncio
async def test_graphiti_init():
    """Run three searches against a Graphiti instance and log the results.

    Connects using the Neo4j settings read from the environment, issues each
    query through ``graphiti.search`` and logs the returned facts formatted by
    ``format_context``. The test makes no assertions; it passes when every
    search completes without raising.
    """
    logger = setup_logging()
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)

    # (text sent to search, label shown in the log). The second query keeps
    # its trailing newline and lower-case spelling from the original test.
    queries = (
        ('Freakenomics guest', 'Freakenomics guest'),
        ('tania tetlow\n', 'Tania Tetlow'),
        ('issues with higher ed', 'issues with higher ed'),
    )

    try:
        for query, label in queries:
            facts = await graphiti.search(query)
            logger.info('\nQUERY: %s\n%s', label, format_context(facts))
    finally:
        # Always release the connection, even when a search fails.
        # NOTE(review): close() is invoked without await, as before; confirm
        # that Graphiti.close is synchronous and does not return a coroutine.
        graphiti.close()
@pytest.mark.asyncio
async def test_graph_integration():
    """Save a small episode/entity graph to Neo4j via the node and edge models.

    Builds one episode ('Alice likes Bob'), two entity nodes (Alice and Bob),
    one episodic edge from the episode to each entity, and a single 'likes'
    entity edge from Alice to Bob. All nodes are saved concurrently, then all
    edges. The test makes no assertions; it passes when every save completes
    without raising.
    """
    driver = AsyncGraphDatabase.driver(
        NEO4J_URI,
        auth=(NEO4j_USER, NEO4j_PASSWORD),
    )
    try:
        embedder = OpenAI().embeddings

        # One timestamp is shared by every created_at / validity field below.
        now = datetime.now()

        episode = EpisodicNode(
            name='test_episode',
            labels=[],
            created_at=now,
            source='message',
            source_description='conversation message',
            content='Alice likes Bob',
            entity_edges=[],
        )

        alice_node = EntityNode(
            name='Alice',
            labels=[],
            created_at=now,
            summary='Alice summary',
        )
        bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')

        # Link the episode to each entity it mentions.
        episodic_edge_1 = EpisodicEdge(
            source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
        )
        episodic_edge_2 = EpisodicEdge(
            source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
        )

        entity_edge = EntityEdge(
            source_node_uuid=alice_node.uuid,
            target_node_uuid=bob_node.uuid,
            created_at=now,
            name='likes',
            fact='Alice likes Bob',
            episodes=[],
            expired_at=now,
            valid_at=now,
            invalid_at=now,
        )

        # NOTE(review): the return value is discarded and the call is not
        # awaited. If generate_embedding is a coroutine function, no embedding
        # is computed here, and the synchronous OpenAI client may need to be
        # replaced by an async one -- confirm against graphiti_core.edges.
        entity_edge.generate_embedding(embedder)

        nodes = [episode, alice_node, bob_node]
        edges = [episodic_edge_1, episodic_edge_2, entity_edge]

        # Nodes are written before edges, preserving the original ordering
        # (presumably so that edge endpoints already exist -- confirm).
        await asyncio.gather(*[node.save(driver) for node in nodes])
        await asyncio.gather(*[edge.save(driver) for edge in edges])
    finally:
        # Release the driver's connections whether or not the saves succeeded.
        await driver.close()