mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
feat: Initial version of temporal invalidation + tests (#8)
* feat: Initial version of temporal invalidation + tests * fix: dont run int tests on CI * fix: dont run int tests on CI * fix: dont run int tests on CI * fix: time of day issue * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * revert: Tests structural changes * chore: Remove idea file * chore: Get rid of NodesWithEdges class and define a triplet type instead
This commit is contained in:
parent
40e74a2e97
commit
a6fd0ddb75
39
.github/workflows/unit_tests.yml
vendored
Normal file
39
.github/workflows/unit_tests.yml
vendored
Normal file
@ -0,0 +1,39 @@
|
||||
name: Unit Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- name: Load cached Poetry installation
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.local
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
- name: Load cached dependencies
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
|
||||
- name: Install dependencies
|
||||
run: poetry install --no-interaction --no-root
|
||||
- name: Run non-integration tests
|
||||
env:
|
||||
PYTHONPATH: ${{ github.workspace }}
|
||||
run: |
|
||||
poetry run pytest -m "not integration"
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -158,5 +158,5 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
.idea/
|
||||
.vscode/
|
||||
|
6
conftest.py
Normal file
6
conftest.py
Normal file
@ -0,0 +1,6 @@
|
||||
import os
import sys

# Put the directory that holds this conftest.py (the project root) at the front
# of sys.path so project packages can be imported when the tests run.
# Without this, test modules may fail with ModuleNotFoundError.
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
|
@ -8,18 +8,22 @@ import os
|
||||
|
||||
from core.llm_client.config import EMBEDDING_DIM
|
||||
from core.nodes import EntityNode, EpisodicNode, Node
|
||||
from core.edges import EntityEdge, Edge, EpisodicEdge
|
||||
from core.edges import EntityEdge, EpisodicEdge
|
||||
from core.utils import (
|
||||
build_episodic_edges,
|
||||
retrieve_relevant_schema,
|
||||
extract_new_edges,
|
||||
extract_new_nodes,
|
||||
clear_data,
|
||||
retrieve_episodes,
|
||||
)
|
||||
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
|
||||
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
|
||||
from core.utils.maintenance.edge_operations import (
|
||||
extract_edges,
|
||||
dedupe_extracted_edges,
|
||||
)
|
||||
|
||||
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
||||
from core.utils.maintenance.temporal_operations import (
|
||||
prepare_edges_for_invalidation,
|
||||
invalidate_edges,
|
||||
)
|
||||
from core.utils.search.search_utils import (
|
||||
edge_similarity_search,
|
||||
entity_fulltext_search,
|
||||
@ -59,21 +63,6 @@ class Graphiti:
|
||||
"""Retrieve the last n episodic nodes from the graph"""
|
||||
return await retrieve_episodes(self.driver, last_n, sources)
|
||||
|
||||
async def retrieve_relevant_schema(self, query: str = None) -> dict[str, any]:
|
||||
"""Retrieve relevant nodes and edges to a specific query"""
|
||||
return await retrieve_relevant_schema(self.driver, query)
|
||||
...
|
||||
|
||||
# Invalidate edges that are no longer valid
|
||||
async def invalidate_edges(
|
||||
self,
|
||||
episode: EpisodicNode,
|
||||
new_nodes: list[EntityNode],
|
||||
new_edges: list[EntityEdge],
|
||||
relevant_schema: dict[str, any],
|
||||
previous_episodes: list[EpisodicNode],
|
||||
): ...
|
||||
|
||||
async def add_episode(
|
||||
self,
|
||||
name: str,
|
||||
@ -102,7 +91,6 @@ class Graphiti:
|
||||
created_at=now,
|
||||
valid_at=reference_time,
|
||||
)
|
||||
# relevant_schema = await self.retrieve_relevant_schema(episode.content)
|
||||
|
||||
extracted_nodes = await extract_nodes(
|
||||
self.llm_client, episode, previous_episodes
|
||||
@ -139,13 +127,32 @@ class Graphiti:
|
||||
f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
|
||||
)
|
||||
|
||||
new_edges = await dedupe_extracted_edges(
|
||||
deduped_edges = await dedupe_extracted_edges(
|
||||
self.llm_client, extracted_edges, existing_edges
|
||||
)
|
||||
|
||||
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")
|
||||
(
|
||||
old_edges_with_nodes_pending_invalidation,
|
||||
new_edges_with_nodes,
|
||||
) = prepare_edges_for_invalidation(
|
||||
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
|
||||
)
|
||||
|
||||
entity_edges.extend(new_edges)
|
||||
invalidated_edges = await invalidate_edges(
|
||||
self.llm_client,
|
||||
old_edges_with_nodes_pending_invalidation,
|
||||
new_edges_with_nodes,
|
||||
)
|
||||
|
||||
entity_edges.extend(invalidated_edges)
|
||||
|
||||
logger.info(
|
||||
f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}"
|
||||
)
|
||||
|
||||
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")
|
||||
|
||||
entity_edges.extend(deduped_edges)
|
||||
episodic_edges.extend(
|
||||
build_episodic_edges(
|
||||
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
|
||||
|
@ -71,7 +71,9 @@ class EntityNode(Node):
|
||||
name_embedding: list[float] | None = Field(
|
||||
default=None, description="embedding of the name"
|
||||
)
|
||||
summary: str = Field(description="regional summary of surrounding edges")
|
||||
summary: str = Field(
|
||||
description="regional summary of surrounding edges", default_factory=str
|
||||
)
|
||||
|
||||
async def update_summary(self, driver: AsyncDriver): ...
|
||||
|
||||
|
50
core/prompts/invalidate_edges.py
Normal file
50
core/prompts/invalidate_edges.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import Protocol, TypedDict
|
||||
from .models import Message, PromptVersion, PromptFunction
|
||||
|
||||
|
||||
class Prompt(Protocol):
    """Structural type for the edge-invalidation prompt: one attribute per prompt version."""

    v1: PromptVersion
|
||||
|
||||
|
||||
class Versions(TypedDict):
    """Mapping from a version name to the function that builds that version's messages."""

    v1: PromptFunction
|
||||
|
||||
|
||||
def v1(context: dict[str, any]) -> list[Message]:
    """Build the v1 chat messages asking which existing edges should be invalidated.

    Args:
        context: Must contain the keys ``"existing_edges"`` and ``"new_edges"``;
            both values are interpolated as-is into the user message.

    Returns:
        Two messages: a system message describing the assistant's role and a
        user message carrying the edges and the required JSON response format.
    """
    return [
        Message(
            role="system",
            content="You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.",
        ),
        Message(
            role="user",
            # Doubled braces below are literal braces in the f-string (the JSON example).
            content=f"""
               Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
               Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
               Do not invalidate relationships merely because they weren't mentioned in new edges.

               Existing Edges (sorted by timestamp, newest first):
               {context['existing_edges']}

               New Edges:
               {context['new_edges']}

               Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)"

               For each existing edge that should be invalidated, respond with a JSON object in the following format:
               {{
                   "invalidated_edges": [
                       {{
                           "edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
                           "reason": "Brief explanation of why this edge is being invalidated"
                       }}
                   ]
               }}

               If no relationships need to be invalidated, return an empty list for "invalidated_edges".
           """,
        ),
    ]
|
||||
|
||||
|
||||
versions: Versions = {"v1": v1}
|
@ -26,12 +26,19 @@ from .dedupe_edges import (
|
||||
versions as dedupe_edges_versions,
|
||||
)
|
||||
|
||||
from .invalidate_edges import (
|
||||
Prompt as InvalidateEdgesPrompt,
|
||||
Versions as InvalidateEdgesVersions,
|
||||
versions as invalidate_edges_versions,
|
||||
)
|
||||
|
||||
|
||||
class PromptLibrary(Protocol):
    """Structural type of the prompt library: one attribute per prompt family."""

    extract_nodes: ExtractNodesPrompt
    dedupe_nodes: DedupeNodesPrompt
    extract_edges: ExtractEdgesPrompt
    dedupe_edges: DedupeEdgesPrompt
    invalidate_edges: InvalidateEdgesPrompt
|
||||
|
||||
|
||||
class PromptLibraryImpl(TypedDict):
|
||||
@ -39,6 +46,7 @@ class PromptLibraryImpl(TypedDict):
|
||||
dedupe_nodes: DedupeNodesVersions
|
||||
extract_edges: ExtractEdgesVersions
|
||||
dedupe_edges: DedupeEdgesVersions
|
||||
invalidate_edges: InvalidateEdgesVersions
|
||||
|
||||
|
||||
class VersionWrapper:
|
||||
@ -66,6 +74,7 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
|
||||
"dedupe_nodes": dedupe_nodes_versions,
|
||||
"extract_edges": extract_edges_versions,
|
||||
"dedupe_edges": dedupe_edges_versions,
|
||||
"invalidate_edges": invalidate_edges_versions,
|
||||
}
|
||||
|
||||
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
|
||||
|
@ -3,7 +3,6 @@ from .maintenance import (
|
||||
build_episodic_edges,
|
||||
extract_new_nodes,
|
||||
clear_data,
|
||||
retrieve_relevant_schema,
|
||||
retrieve_episodes,
|
||||
)
|
||||
|
||||
@ -12,6 +11,5 @@ __all__ = [
|
||||
"build_episodic_edges",
|
||||
"extract_new_nodes",
|
||||
"clear_data",
|
||||
"retrieve_relevant_schema",
|
||||
"retrieve_episodes",
|
||||
]
|
||||
|
@ -2,7 +2,6 @@ from .edge_operations import extract_new_edges, build_episodic_edges
|
||||
from .node_operations import extract_new_nodes
|
||||
from .graph_data_operations import (
|
||||
clear_data,
|
||||
retrieve_relevant_schema,
|
||||
retrieve_episodes,
|
||||
)
|
||||
|
||||
@ -11,6 +10,6 @@ __all__ = [
|
||||
"build_episodic_edges",
|
||||
"extract_new_nodes",
|
||||
"clear_data",
|
||||
"retrieve_relevant_schema",
|
||||
"retrieve_episodes",
|
||||
"invalidate_edges",
|
||||
]
|
||||
|
@ -2,6 +2,8 @@ import json
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.nodes import EntityNode, EpisodicNode
|
||||
from core.edges import EpisodicEdge, EntityEdge
|
||||
import logging
|
||||
|
@ -17,52 +17,6 @@ async def clear_data(driver: AsyncDriver):
|
||||
await session.execute_write(delete_all)
|
||||
|
||||
|
||||
async def retrieve_relevant_schema(
    driver: AsyncDriver, query: str | None = None
) -> dict[str, any]:
    """Summarise every node and its outgoing relationships as a plain dict.

    Args:
        driver: Driver whose ``session()`` is used to run the summary query.
        query: Currently unused; kept so existing callers keep working.

    Returns:
        ``{"nodes": {...}, "relationships": [...]}`` where ``nodes`` maps a node
        name to its uuid, first label and outgoing relationships, and
        ``relationships`` is the flat list of (source, type, name, target) dicts.
        A node without labels gets ``None`` as its label.
    """
    summary_query = """
        MATCH (n)
        OPTIONAL MATCH (n)-[r]->(m)
        RETURN DISTINCT labels(n) AS node_labels, n.uuid AS node_uuid, n.name AS node_name,
            type(r) AS relationship_type, r.name AS relationship_name, m.name AS related_node_name
        """

    async with driver.session() as session:
        result = await session.run(summary_query)
        records = [record async for record in result]

    schema = {"nodes": {}, "relationships": []}

    for record in records:
        # Only the first label is reported; an unlabeled node must not raise IndexError.
        node_labels = record["node_labels"]
        node_label = node_labels[0] if node_labels else None
        node_name = record["node_name"]
        rel_type = record["relationship_type"]
        rel_name = record["relationship_name"]
        related_node = record["related_node_name"]

        if node_name not in schema["nodes"]:
            schema["nodes"][node_name] = {
                "uuid": record["node_uuid"],
                "label": node_label,
                "relationships": [],
            }

        # OPTIONAL MATCH yields nulls for nodes with no outgoing relationship.
        if rel_type and related_node:
            schema["nodes"][node_name]["relationships"].append(
                {"type": rel_type, "name": rel_name, "target": related_node}
            )
            schema["relationships"].append(
                {
                    "source": node_name,
                    "type": rel_type,
                    "name": rel_name,
                    "target": related_node,
                }
            )

    return schema
|
||||
|
||||
|
||||
async def retrieve_episodes(
|
||||
driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
|
||||
) -> list[EpisodicNode]:
|
||||
|
100
core/utils/maintenance/temporal_operations.py
Normal file
100
core/utils/maintenance/temporal_operations.py
Normal file
@ -0,0 +1,100 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from core.llm_client import LLMClient
|
||||
from core.edges import EntityEdge
|
||||
from core.nodes import EntityNode
|
||||
from core.prompts import prompt_library
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
|
||||
|
||||
|
||||
def prepare_edges_for_invalidation(
    existing_edges: list[EntityEdge],
    new_edges: list[EntityEdge],
    nodes: list[EntityNode],
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
    """Pair each edge with its source and target node.

    Args:
        existing_edges: Edges already known, candidates for invalidation.
        new_edges: Newly extracted edges.
        nodes: Nodes used to resolve ``source_node_uuid`` / ``target_node_uuid``.

    Returns:
        ``(existing_triplets, new_triplets)``, each a list of
        ``(source_node, edge, target_node)`` in the input edge order. An edge is
        dropped when either of its endpoints is not present in ``nodes``.
    """
    # Index nodes once instead of scanning the list for every edge endpoint.
    # setdefault keeps the first node seen for a UUID, like a first-match scan.
    nodes_by_uuid: dict[str, EntityNode] = {}
    for node in nodes:
        nodes_by_uuid.setdefault(node.uuid, node)

    def attach_nodes(edges: list[EntityEdge]) -> list[NodeEdgeNodeTriplet]:
        """Return a triplet for every edge whose two endpoints are known."""
        triplets = []
        for edge in edges:
            source_node = nodes_by_uuid.get(edge.source_node_uuid)
            target_node = nodes_by_uuid.get(edge.target_node_uuid)
            if source_node is not None and target_node is not None:
                triplets.append((source_node, edge, target_node))
        return triplets

    return attach_nodes(existing_edges), attach_nodes(new_edges)
|
||||
|
||||
|
||||
async def invalidate_edges(
    llm_client: LLMClient,
    existing_edges_pending_invalidation: List[NodeEdgeNodeTriplet],
    new_edges: List[NodeEdgeNodeTriplet],
) -> List[EntityEdge]:
    """Ask the LLM which existing edges are contradicted by the new edges.

    Args:
        llm_client: Client whose ``generate_response`` is awaited with the
            ``invalidate_edges.v1`` prompt.
        existing_edges_pending_invalidation: Triplets that may be invalidated.
        new_edges: Triplets carrying the newer information.

    Returns:
        The existing edges named in the response's ``"invalidated_edges"`` list,
        as returned by ``process_edge_invalidation_llm_response``. Empty when
        there is nothing to invalidate or no new information to invalidate with.
    """
    # Only existing edges can be returned, and the prompt only allows invalidation
    # on evidence from new edges, so an empty side makes the LLM call pointless.
    if not existing_edges_pending_invalidation or not new_edges:
        return []

    context = prepare_invalidation_context(
        existing_edges_pending_invalidation, new_edges
    )
    llm_response = await llm_client.generate_response(
        prompt_library.invalidate_edges.v1(context)
    )

    # `or []` also covers a response that carries an explicit null for the key.
    edges_to_invalidate = llm_response.get("invalidated_edges") or []
    return process_edge_invalidation_llm_response(
        edges_to_invalidate, existing_edges_pending_invalidation
    )
|
||||
|
||||
|
||||
def prepare_invalidation_context(
    existing_edges: List[NodeEdgeNodeTriplet], new_edges: List[NodeEdgeNodeTriplet]
) -> dict:
    """Render both triplet lists as text lines for the invalidation prompt.

    Args:
        existing_edges: ``(source_node, edge, target_node)`` triplets already known.
        new_edges: Triplets for the newly extracted edges.

    Returns:
        A dict with the keys ``"existing_edges"`` and ``"new_edges"``. Each value
        is a list of ``"UUID | SOURCE - EDGE_NAME - TARGET (ISO timestamp)"``
        strings ordered by the edge's ``created_at``, newest first.
    """

    def describe_newest_first(triplets: List[NodeEdgeNodeTriplet]) -> List[str]:
        """Format triplets one per line, sorted by edge creation time descending."""
        ordered = sorted(
            triplets, key=lambda triplet: triplet[1].created_at, reverse=True
        )
        return [
            f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})"
            for source_node, edge, target_node in ordered
        ]

    return {
        "existing_edges": describe_newest_first(existing_edges),
        "new_edges": describe_newest_first(new_edges),
    }
|
||||
|
||||
|
||||
def process_edge_invalidation_llm_response(
    edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
) -> List[EntityEdge]:
    """Expire the existing edges that the LLM response names.

    Args:
        edges_to_invalidate: Entries from the response, each expected to carry
            ``"edge_uuid"`` and ``"reason"``. Malformed entries and UUIDs that
            match no existing edge are skipped.
        existing_edges: ``(source_node, edge, target_node)`` triplets to search.

    Returns:
        The matched edges, each at most once, in response order. Every returned
        edge has had ``expired_at`` set to the current time (mutated in place).
    """
    # Index once; setdefault keeps the first edge for a UUID, like a first-match scan.
    edges_by_uuid: dict = {}
    for _, edge, _ in existing_edges:
        edges_by_uuid.setdefault(edge.uuid, edge)

    invalidated_edges = []
    handled_uuids = set()
    for edge_to_invalidate in edges_to_invalidate:
        # The payload comes from an LLM, so tolerate entries of the wrong shape.
        if not isinstance(edge_to_invalidate, dict):
            continue
        edge_uuid = edge_to_invalidate.get("edge_uuid")
        if not isinstance(edge_uuid, str) or edge_uuid in handled_uuids:
            continue
        edge_to_update = edges_by_uuid.get(edge_uuid)
        if edge_to_update is None:
            continue

        handled_uuids.add(edge_uuid)
        edge_to_update.expired_at = datetime.now()
        invalidated_edges.append(edge_to_update)
        logger.info(
            "Invalidated edge: %s (UUID: %s). Reason: %s",
            edge_to_update.name,
            edge_to_update.uuid,
            edge_to_invalidate.get("reason", "no reason given"),
        )
    return invalidated_edges
|
@ -271,4 +271,4 @@ async def get_relevant_edges(
|
||||
|
||||
logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
|
||||
|
||||
return relevant_edges.values()
|
||||
return list(relevant_edges.values())
|
||||
|
@ -42,7 +42,7 @@ async def main():
|
||||
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
await clear_data(client.driver)
|
||||
messages = parse_podcast_messages()
|
||||
for i, message in enumerate(messages[3:14]):
|
||||
for i, message in enumerate(messages[3:50]):
|
||||
await client.add_episode(
|
||||
name=f"Message {i}",
|
||||
episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
|
||||
|
3054
poetry.lock
generated
3054
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -10,17 +10,11 @@ python = "^3.10"
|
||||
pydantic = "^2.8.2"
|
||||
fastapi = "^0.112.0"
|
||||
neo4j = "^5.23.0"
|
||||
chromadb = "^0.5.5"
|
||||
sentence-transformers = "^3.0.1"
|
||||
diskcache = "^5.6.3"
|
||||
tiktoken = "^0.7.0"
|
||||
deepeval = "^0.21.74"
|
||||
arrow = "^1.3.0"
|
||||
groq = "^0.9.0"
|
||||
openai = "^1.38.0"
|
||||
tqdm = "^4.66.4"
|
||||
python-dotenv = "^1.0.1"
|
||||
pandas = "^2.2.2"
|
||||
pytest-asyncio = "^0.23.8"
|
||||
pytest-xdist = "^3.6.1"
|
||||
pytest = "^8.3.2"
|
||||
@ -29,3 +23,6 @@ pytest = "^8.3.2"
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["."]
|
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@ -0,0 +1,3 @@
|
||||
[pytest]
|
||||
markers =
|
||||
integration: marks tests as integration tests
|
16
runner.py
16
runner.py
@ -49,19 +49,29 @@ async def main():
|
||||
)
|
||||
await client.add_episode(
|
||||
name="Message 2",
|
||||
episode_body="Paul: I own many bananas",
|
||||
episode_body="Paul: I hate apples now",
|
||||
source_description="WhatsApp Message",
|
||||
)
|
||||
await client.add_episode(
|
||||
name="Message 3",
|
||||
episode_body="Assistant: The best type of apples available are Fuji apples",
|
||||
episode_body="Jane: I am married to Paul",
|
||||
source_description="WhatsApp Message",
|
||||
)
|
||||
await client.add_episode(
|
||||
name="Message 4",
|
||||
episode_body="Paul: Oh, I actually hate those",
|
||||
episode_body="Paul: I have divorced Jane",
|
||||
source_description="WhatsApp Message",
|
||||
)
|
||||
# await client.add_episode(
|
||||
# name="Message 3",
|
||||
# episode_body="Assistant: The best type of apples available are Fuji apples",
|
||||
# source_description="WhatsApp Message",
|
||||
# )
|
||||
# await client.add_episode(
|
||||
# name="Message 4",
|
||||
# episode_body="Paul: Oh, I actually hate those",
|
||||
# source_description="WhatsApp Message",
|
||||
# )
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
@ -3,6 +3,9 @@ import sys
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
|
233
tests/utils/maintenance/test_temporal_operations.py
Normal file
233
tests/utils/maintenance/test_temporal_operations.py
Normal file
@ -0,0 +1,233 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from core.utils.maintenance.temporal_operations import (
|
||||
prepare_edges_for_invalidation,
|
||||
prepare_invalidation_context,
|
||||
)
|
||||
from core.edges import EntityEdge
|
||||
from core.nodes import EntityNode
|
||||
|
||||
|
||||
# Helper function to create test data
|
||||
def create_test_data():
    """Return three Person nodes plus two existing and two new edges between them.

    The result is a dict with the keys "nodes", "existing_edges" and "new_edges".
    All objects share one creation timestamp.
    """
    timestamp = datetime.now()

    def person(uuid, name):
        return EntityNode(uuid=uuid, name=name, labels=["Person"], created_at=timestamp)

    def link(uuid, source, target, name, fact):
        return EntityEdge(
            uuid=uuid,
            source_node_uuid=source,
            target_node_uuid=target,
            name=name,
            fact=fact,
            created_at=timestamp,
        )

    return {
        "nodes": [person("1", "Node1"), person("2", "Node2"), person("3", "Node3")],
        "existing_edges": [
            link("e1", "1", "2", "KNOWS", "Node1 knows Node2"),
            link("e2", "2", "3", "LIKES", "Node2 likes Node3"),
        ],
        "new_edges": [
            link("e3", "1", "3", "WORKS_WITH", "Node1 works with Node3"),
            link("e4", "1", "2", "DISLIKES", "Node1 dislikes Node2"),
        ],
    }
|
||||
|
||||
|
||||
def test_prepare_edges_for_invalidation_basic():
    """With all nodes present, every edge comes back as a (node, edge, node) triplet."""
    data = create_test_data()

    pending, fresh = prepare_edges_for_invalidation(
        data["existing_edges"], data["new_edges"], data["nodes"]
    )

    assert len(pending) == 2
    assert len(fresh) == 2

    # Each triplet must be node / edge / node, in that order.
    for source, edge, target in pending + fresh:
        assert isinstance(source, EntityNode)
        assert isinstance(edge, EntityEdge)
        assert isinstance(target, EntityNode)
|
||||
|
||||
|
||||
def test_prepare_edges_for_invalidation_no_existing_edges():
    """An empty existing-edge list yields no pending triplets; new edges are unaffected."""
    data = create_test_data()

    pending, fresh = prepare_edges_for_invalidation(
        [], data["new_edges"], data["nodes"]
    )

    assert len(pending) == 0
    assert len(fresh) == 2
|
||||
|
||||
|
||||
def test_prepare_edges_for_invalidation_no_new_edges():
    """An empty new-edge list yields no new triplets; existing edges are unaffected."""
    data = create_test_data()

    pending, fresh = prepare_edges_for_invalidation(
        data["existing_edges"], [], data["nodes"]
    )

    assert len(pending) == 2
    assert len(fresh) == 0
|
||||
|
||||
|
||||
def test_prepare_edges_for_invalidation_missing_nodes():
    """Edges that touch a node absent from the node list are left out."""
    data = create_test_data()

    # Drop the last node so one existing edge and one new edge lose an endpoint.
    known_nodes = data["nodes"][:-1]

    pending, fresh = prepare_edges_for_invalidation(
        data["existing_edges"], data["new_edges"], known_nodes
    )

    assert len(pending) == 1
    assert len(fresh) == 1
|
||||
|
||||
|
||||
def test_prepare_invalidation_context():
    """Each context line carries the edge uuid, both node names, the edge name and its timestamp."""
    created = datetime.now()

    first, second, third = (
        EntityNode(uuid=uuid, name=name, labels=["Person"], created_at=created)
        for uuid, name in (("1", "Node1"), ("2", "Node2"), ("3", "Node3"))
    )

    knows = EntityEdge(
        uuid="e1",
        source_node_uuid="1",
        target_node_uuid="2",
        name="KNOWS",
        fact="Node1 knows Node2",
        created_at=created,
    )
    likes = EntityEdge(
        uuid="e2",
        source_node_uuid="2",
        target_node_uuid="3",
        name="LIKES",
        fact="Node2 likes Node3",
        created_at=created,
    )

    context = prepare_invalidation_context(
        [(first, knows, second)], [(second, likes, third)]
    )

    assert isinstance(context, dict)
    assert "existing_edges" in context
    assert "new_edges" in context
    assert len(context["existing_edges"]) == 1
    assert len(context["new_edges"]) == 1

    # Same field checks for the existing edge line and then the new edge line.
    expectations = (
        (context["existing_edges"][0], first, knows, second),
        (context["new_edges"][0], second, likes, third),
    )
    for line, source, edge, target in expectations:
        assert edge.uuid in line
        assert source.name in line
        assert edge.name in line
        assert target.name in line
        assert edge.created_at.isoformat() in line
|
||||
|
||||
|
||||
def test_prepare_invalidation_context_empty_input():
    """Empty inputs still produce both keys, each mapped to an empty list."""
    context = prepare_invalidation_context([], [])

    assert isinstance(context, dict)
    for key in ("existing_edges", "new_edges"):
        assert key in context
    for key in ("existing_edges", "new_edges"):
        assert len(context[key]) == 0
|
||||
|
||||
|
||||
def test_prepare_invalidation_context_sorting():
    """Existing edges are listed newest first regardless of input order."""
    base_time = datetime.now()

    node_a = EntityNode(uuid="1", name="Node1", labels=["Person"], created_at=base_time)
    node_b = EntityNode(uuid="2", name="Node2", labels=["Person"], created_at=base_time)

    older = EntityEdge(
        uuid="e1",
        source_node_uuid="1",
        target_node_uuid="2",
        name="KNOWS",
        fact="Node1 knows Node2",
        created_at=base_time,
    )
    newer = EntityEdge(
        uuid="e2",
        source_node_uuid="2",
        target_node_uuid="1",
        name="LIKES",
        fact="Node2 likes Node1",
        created_at=base_time + timedelta(hours=1),
    )

    # The older edge is passed first on purpose.
    context = prepare_invalidation_context(
        [(node_a, older, node_b), (node_b, newer, node_a)], []
    )
    ordered = context["existing_edges"]

    assert len(ordered) == 2
    assert newer.uuid in ordered[0]  # The newer edge should be first
    assert older.uuid in ordered[1]  # The older edge should be second
|
||||
|
||||
|
||||
# Run the tests
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
306
tests/utils/maintenance/test_temporal_operations_int.py
Normal file
306
tests/utils/maintenance/test_temporal_operations_int.py
Normal file
@ -0,0 +1,306 @@
|
||||
import pytest
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from core.utils.maintenance.temporal_operations import (
|
||||
invalidate_edges,
|
||||
)
|
||||
from core.edges import EntityEdge
|
||||
from core.nodes import EntityNode
|
||||
from core.llm_client import OpenAIClient, LLMConfig
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def setup_llm_client():
    """Create an OpenAIClient configured from TEST_OPENAI_API_KEY and TEST_OPENAI_MODEL."""
    config = LLMConfig(
        api_key=os.getenv("TEST_OPENAI_API_KEY"),
        model=os.getenv("TEST_OPENAI_MODEL"),
        base_url="https://api.openai.com/v1",
    )
    return OpenAIClient(config)
|
||||
|
||||
|
||||
# Helper function to create test data
|
||||
def create_test_data():
    """Return an (existing, new) pair of triplets between Alice and Bob.

    The existing edge is "Alice likes Bob", created a day ago; the new edge is
    "Alice dislikes Bob", created now.
    """
    current = datetime.now()

    alice = EntityNode(uuid="1", name="Alice", labels=["Person"], created_at=current)
    bob = EntityNode(uuid="2", name="Bob", labels=["Person"], created_at=current)

    likes = EntityEdge(
        uuid="e1",
        source_node_uuid="1",
        target_node_uuid="2",
        name="LIKES",
        fact="Alice likes Bob",
        created_at=current - timedelta(days=1),
    )
    dislikes = EntityEdge(
        uuid="e2",
        source_node_uuid="1",
        target_node_uuid="2",
        name="DISLIKES",
        fact="Alice dislikes Bob",
        created_at=current,
    )

    return (alice, likes, bob), (alice, dislikes, bob)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges():
    """A contradicting new edge causes the single existing edge to be expired."""
    stale, fresh = create_test_data()

    result = await invalidate_edges(setup_llm_client(), [stale], [fresh])

    assert len(result) == 1
    expired = result[0]
    assert expired.uuid == stale[1].uuid
    assert expired.expired_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_invalidation():
    """Without any new edges, nothing is invalidated."""
    stale, _ = create_test_data()

    result = await invalidate_edges(setup_llm_client(), [stale], [])

    assert len(result) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_multiple_existing():
    """Only the contradicted edge is expired when a second, unrelated edge exists."""
    contradicted, fresh = create_test_data()

    # Second existing edge, turned into "Alice knows Bob" under a different uuid.
    unrelated, _ = create_test_data()
    unrelated_edge = unrelated[1]
    unrelated_edge.uuid = "e3"
    unrelated_edge.name = "KNOWS"
    unrelated_edge.fact = "Alice knows Bob"

    result = await invalidate_edges(
        setup_llm_client(), [contradicted, unrelated], [fresh]
    )

    assert len(result) == 1
    expired = result[0]
    assert expired.uuid == contradicted[1].uuid
    assert expired.expired_at is not None
|
||||
|
||||
|
||||
# Helper function to create more complex test data
|
||||
def create_complex_test_data():
    """Return three existing triplets and the four nodes they are built from.

    Nodes are Alice, Bob, Charlie (Person) and Company XYZ (Organization). The
    edges are created five, three and two days before now respectively.
    """
    current = datetime.now()

    alice = EntityNode(uuid="1", name="Alice", labels=["Person"], created_at=current)
    bob = EntityNode(uuid="2", name="Bob", labels=["Person"], created_at=current)
    charlie = EntityNode(uuid="3", name="Charlie", labels=["Person"], created_at=current)
    company = EntityNode(
        uuid="4", name="Company XYZ", labels=["Organization"], created_at=current
    )

    def link(uuid, source, target, name, fact, days_ago):
        return EntityEdge(
            uuid=uuid,
            source_node_uuid=source,
            target_node_uuid=target,
            name=name,
            fact=fact,
            created_at=current - timedelta(days=days_ago),
        )

    triplets = [
        (alice, link("e1", "1", "2", "LIKES", "Alice likes Bob", 5), bob),
        (
            alice,
            link("e2", "1", "3", "FRIENDS_WITH", "Alice is friends with Charlie", 3),
            charlie,
        ),
        (bob, link("e3", "2", "4", "WORKS_FOR", "Bob works for Company XYZ", 2), company),
    ]

    return triplets, [alice, bob, charlie, company]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_complex():
    """A contradicting Alice -> Bob edge should invalidate only edge ``e1``."""
    existing_edges, nodes = create_complex_test_data()
    alice, bob = nodes[0], nodes[1]

    # New fact that contradicts the existing "Alice likes Bob" edge.
    dislikes = EntityEdge(
        uuid="e4",
        source_node_uuid="1",
        target_node_uuid="2",
        name="DISLIKES",
        fact="Alice dislikes Bob",
        created_at=datetime.now(),
    )
    new_edge = (alice, dislikes, bob)

    llm_client = setup_llm_client()
    invalidated_edges = await invalidate_edges(llm_client, existing_edges, [new_edge])

    assert len(invalidated_edges) == 1
    only_invalidated = invalidated_edges[0]
    assert only_invalidated.uuid == "e1"
    assert only_invalidated.expired_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_temporal_update():
    """A "Bob left his job" edge should invalidate only edge ``e3``."""
    existing_edges, nodes = create_complex_test_data()
    bob, company = nodes[1], nodes[3]

    # New fact that supersedes the existing "Bob works for Company XYZ" edge.
    left_job = EntityEdge(
        uuid="e5",
        source_node_uuid="2",
        target_node_uuid="4",
        name="LEFT_JOB",
        fact="Bob left his job at Company XYZ",
        created_at=datetime.now(),
    )
    new_edge = (bob, left_job, company)

    llm_client = setup_llm_client()
    invalidated_edges = await invalidate_edges(llm_client, existing_edges, [new_edge])

    assert len(invalidated_edges) == 1
    only_invalidated = invalidated_edges[0]
    assert only_invalidated.uuid == "e3"
    assert only_invalidated.expired_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_multiple_invalidations():
    """Two new contradicting edges should invalidate edges ``e1`` and ``e2``."""
    existing_edges, nodes = create_complex_test_data()
    alice, bob, charlie = nodes[0], nodes[1], nodes[2]

    # Contradicts the existing "Alice likes Bob" edge.
    enemies = EntityEdge(
        uuid="e6",
        source_node_uuid="1",
        target_node_uuid="2",
        name="ENEMIES_WITH",
        fact="Alice and Bob are now enemies",
        created_at=datetime.now(),
    )
    # Contradicts the existing "Alice is friends with Charlie" edge.
    ended_friendship = EntityEdge(
        uuid="e7",
        source_node_uuid="1",
        target_node_uuid="3",
        name="ENDED_FRIENDSHIP",
        fact="Alice ended her friendship with Charlie",
        created_at=datetime.now(),
    )
    new_edges = [(alice, enemies, bob), (alice, ended_friendship, charlie)]

    llm_client = setup_llm_client()
    invalidated_edges = await invalidate_edges(llm_client, existing_edges, new_edges)

    assert len(invalidated_edges) == 2
    assert {invalidated.uuid for invalidated in invalidated_edges} == {"e1", "e2"}
    for invalidated in invalidated_edges:
        assert invalidated.expired_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_effect():
    """An unrelated Charlie -> Company XYZ edge should invalidate nothing."""
    existing_edges, nodes = create_complex_test_data()
    charlie, company = nodes[2], nodes[3]

    # New fact that does not conflict with any edge in the fixture.
    applied_to = EntityEdge(
        uuid="e8",
        source_node_uuid="3",
        target_node_uuid="4",
        name="APPLIED_TO",
        fact="Charlie applied to Company XYZ",
        created_at=datetime.now(),
    )
    new_edge = (charlie, applied_to, company)

    llm_client = setup_llm_client()
    invalidated_edges = await invalidate_edges(llm_client, existing_edges, [new_edge])

    assert len(invalidated_edges) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_partial_update():
    """A "changed position" edge should not invalidate the WORKS_FOR edge."""
    existing_edges, nodes = create_complex_test_data()
    bob, company = nodes[1], nodes[3]

    # New fact that refines, rather than contradicts, "Bob works for Company XYZ".
    changed_position = EntityEdge(
        uuid="e9",
        source_node_uuid="2",
        target_node_uuid="4",
        name="CHANGED_POSITION",
        fact="Bob changed his position at Company XYZ",
        created_at=datetime.now(),
    )
    new_edge = (bob, changed_position, company)

    llm_client = setup_llm_client()
    invalidated_edges = await invalidate_edges(llm_client, existing_edges, [new_edge])

    # The existing edge is not invalidated, just updated.
    assert len(invalidated_edges) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_empty_inputs():
    """With no existing and no new edges, nothing should be invalidated."""
    llm_client = setup_llm_client()
    no_existing_edges = []
    no_new_edges = []

    invalidated_edges = await invalidate_edges(
        llm_client, no_existing_edges, no_new_edges
    )

    assert len(invalidated_edges) == 0
|
Loading…
x
Reference in New Issue
Block a user