feat: Initial version of temporal invalidation + tests (#8)

* feat: Initial version of temporal invalidation + tests

* fix: dont run int tests on CI

* fix: dont run int tests on CI

* fix: dont run int tests on CI

* fix: time of day issue

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* fix: running non int tests in ci

* revert: Tests structural changes

* chore: Remove idea file

* chore: Get rid of NodesWithEdges class and define a triplet type instead
This commit is contained in:
Pavlo Paliychuk 2024-08-20 16:29:19 -04:00 committed by GitHub
parent 40e74a2e97
commit a6fd0ddb75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 877 additions and 3073 deletions

39
.github/workflows/unit_tests.yml vendored Normal file
View File

@ -0,0 +1,39 @@
name: Unit Tests
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Load cached Poetry installation
uses: actions/cache@v3
with:
path: ~/.local
key: poetry-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
- name: Load cached dependencies
uses: actions/cache@v3
with:
path: .venv
key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies
run: poetry install --no-interaction --no-root
- name: Run non-integration tests
env:
PYTHONPATH: ${{ github.workspace }}
run: |
poetry run pytest -m "not integration"

4
.gitignore vendored
View File

@ -158,5 +158,5 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
.vscode/

6
conftest.py Normal file
View File

@ -0,0 +1,6 @@
import sys
import os
# This code adds the project root directory to the Python path, allowing imports to work correctly when running tests.
# Without this file, you might encounter ModuleNotFoundError when trying to import modules from your project, especially when running tests.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__))))

View File

@ -8,18 +8,22 @@ import os
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.edges import EntityEdge, EpisodicEdge
from core.utils import (
build_episodic_edges,
retrieve_relevant_schema,
extract_new_edges,
extract_new_nodes,
clear_data,
retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.edge_operations import (
extract_edges,
dedupe_extracted_edges,
)
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
prepare_edges_for_invalidation,
invalidate_edges,
)
from core.utils.search.search_utils import (
edge_similarity_search,
entity_fulltext_search,
@ -59,21 +63,6 @@ class Graphiti:
"""Retrieve the last n episodic nodes from the graph"""
return await retrieve_episodes(self.driver, last_n, sources)
async def retrieve_relevant_schema(self, query: str = None) -> dict[str, any]:
"""Retrieve relevant nodes and edges to a specific query"""
return await retrieve_relevant_schema(self.driver, query)
...
# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
async def add_episode(
self,
name: str,
@ -102,7 +91,6 @@ class Graphiti:
created_at=now,
valid_at=reference_time,
)
# relevant_schema = await self.retrieve_relevant_schema(episode.content)
extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes
@ -139,13 +127,32 @@ class Graphiti:
f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
)
new_edges = await dedupe_extracted_edges(
deduped_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")
(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
) = prepare_edges_for_invalidation(
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
)
entity_edges.extend(new_edges)
invalidated_edges = await invalidate_edges(
self.llm_client,
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
)
entity_edges.extend(invalidated_edges)
logger.info(
f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}"
)
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")
entity_edges.extend(deduped_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them

View File

@ -71,7 +71,9 @@ class EntityNode(Node):
name_embedding: list[float] | None = Field(
default=None, description="embedding of the name"
)
summary: str = Field(description="regional summary of surrounding edges")
summary: str = Field(
description="regional summary of surrounding edges", default_factory=str
)
async def update_summary(self, driver: AsyncDriver): ...

View File

@ -0,0 +1,50 @@
from typing import Protocol, TypedDict
from .models import Message, PromptVersion, PromptFunction
class Prompt(Protocol):
v1: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.",
),
Message(
role="user",
content=f"""
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
Do not invalidate relationships merely because they weren't mentioned in new edges.
Existing Edges (sorted by timestamp, newest first):
{context['existing_edges']}
New Edges:
{context['new_edges']}
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)"
For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{
"invalidated_edges": [
{{
"edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
"reason": "Brief explanation of why this edge is being invalidated"
}}
]
}}
If no relationships need to be invalidated, return an empty list for "invalidated_edges".
""",
),
]
versions: Versions = {"v1": v1}

View File

@ -26,12 +26,19 @@ from .dedupe_edges import (
versions as dedupe_edges_versions,
)
from .invalidate_edges import (
Prompt as InvalidateEdgesPrompt,
Versions as InvalidateEdgesVersions,
versions as invalidate_edges_versions,
)
class PromptLibrary(Protocol):
extract_nodes: ExtractNodesPrompt
dedupe_nodes: DedupeNodesPrompt
extract_edges: ExtractEdgesPrompt
dedupe_edges: DedupeEdgesPrompt
invalidate_edges: InvalidateEdgesPrompt
class PromptLibraryImpl(TypedDict):
@ -39,6 +46,7 @@ class PromptLibraryImpl(TypedDict):
dedupe_nodes: DedupeNodesVersions
extract_edges: ExtractEdgesVersions
dedupe_edges: DedupeEdgesVersions
invalidate_edges: InvalidateEdgesVersions
class VersionWrapper:
@ -66,6 +74,7 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
"dedupe_nodes": dedupe_nodes_versions,
"extract_edges": extract_edges_versions,
"dedupe_edges": dedupe_edges_versions,
"invalidate_edges": invalidate_edges_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)

View File

@ -3,7 +3,6 @@ from .maintenance import (
build_episodic_edges,
extract_new_nodes,
clear_data,
retrieve_relevant_schema,
retrieve_episodes,
)
@ -12,6 +11,5 @@ __all__ = [
"build_episodic_edges",
"extract_new_nodes",
"clear_data",
"retrieve_relevant_schema",
"retrieve_episodes",
]

View File

@ -2,7 +2,6 @@ from .edge_operations import extract_new_edges, build_episodic_edges
from .node_operations import extract_new_nodes
from .graph_data_operations import (
clear_data,
retrieve_relevant_schema,
retrieve_episodes,
)
@ -11,6 +10,6 @@ __all__ = [
"build_episodic_edges",
"extract_new_nodes",
"clear_data",
"retrieve_relevant_schema",
"retrieve_episodes",
"invalidate_edges",
]

View File

@ -2,6 +2,8 @@ import json
from typing import List
from datetime import datetime
from pydantic import BaseModel
from core.nodes import EntityNode, EpisodicNode
from core.edges import EpisodicEdge, EntityEdge
import logging

View File

@ -17,52 +17,6 @@ async def clear_data(driver: AsyncDriver):
await session.execute_write(delete_all)
async def retrieve_relevant_schema(
driver: AsyncDriver, query: str = None
) -> dict[str, any]:
async with driver.session() as session:
summary_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]->(m)
RETURN DISTINCT labels(n) AS node_labels, n.uuid AS node_uuid, n.name AS node_name,
type(r) AS relationship_type, r.name AS relationship_name, m.name AS related_node_name
"""
result = await session.run(summary_query)
records = [record async for record in result]
schema = {"nodes": {}, "relationships": []}
for record in records:
node_label = record["node_labels"][0] # Assuming one label per node
node_uuid = record["node_uuid"]
node_name = record["node_name"]
rel_type = record["relationship_type"]
rel_name = record["relationship_name"]
related_node = record["related_node_name"]
if node_name not in schema["nodes"]:
schema["nodes"][node_name] = {
"uuid": node_uuid,
"label": node_label,
"relationships": [],
}
if rel_type and related_node:
schema["nodes"][node_name]["relationships"].append(
{"type": rel_type, "name": rel_name, "target": related_node}
)
schema["relationships"].append(
{
"source": node_name,
"type": rel_type,
"name": rel_name,
"target": related_node,
}
)
return schema
async def retrieve_episodes(
driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
) -> list[EpisodicNode]:

View File

@ -0,0 +1,100 @@
from datetime import datetime
from typing import List
from core.llm_client import LLMClient
from core.edges import EntityEdge
from core.nodes import EntityNode
from core.prompts import prompt_library
import logging
logger = logging.getLogger(__name__)
NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
def prepare_edges_for_invalidation(
existing_edges: list[EntityEdge],
new_edges: list[EntityEdge],
nodes: list[EntityNode],
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
existing_edges_pending_invalidation = []
new_edges_with_nodes = []
existing_edges_pending_invalidation = []
new_edges_with_nodes = []
for edge_list, result_list in [
(existing_edges, existing_edges_pending_invalidation),
(new_edges, new_edges_with_nodes),
]:
for edge in edge_list:
source_node = next(
(node for node in nodes if node.uuid == edge.source_node_uuid), None
)
target_node = next(
(node for node in nodes if node.uuid == edge.target_node_uuid), None
)
if source_node and target_node:
result_list.append((source_node, edge, target_node))
return existing_edges_pending_invalidation, new_edges_with_nodes
async def invalidate_edges(
llm_client: LLMClient,
existing_edges_pending_invalidation: List[NodeEdgeNodeTriplet],
new_edges: List[NodeEdgeNodeTriplet],
) -> List[EntityEdge]:
invalidated_edges = []
context = prepare_invalidation_context(
existing_edges_pending_invalidation, new_edges
)
llm_response = await llm_client.generate_response(
prompt_library.invalidate_edges.v1(context)
)
edges_to_invalidate = llm_response.get("invalidated_edges", [])
invalidated_edges = process_edge_invalidation_llm_response(
edges_to_invalidate, existing_edges_pending_invalidation
)
return invalidated_edges
def prepare_invalidation_context(
existing_edges: List[NodeEdgeNodeTriplet], new_edges: List[NodeEdgeNodeTriplet]
) -> dict:
return {
"existing_edges": [
f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})"
for source_node, edge, target_node in sorted(
existing_edges, key=lambda x: x[1].created_at, reverse=True
)
],
"new_edges": [
f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})"
for source_node, edge, target_node in sorted(
new_edges, key=lambda x: x[1].created_at, reverse=True
)
],
}
def process_edge_invalidation_llm_response(
edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
) -> List[EntityEdge]:
invalidated_edges = []
for edge_to_invalidate in edges_to_invalidate:
edge_uuid = edge_to_invalidate["edge_uuid"]
edge_to_update = next(
(edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
None,
)
if edge_to_update:
edge_to_update.expired_at = datetime.now()
invalidated_edges.append(edge_to_update)
logger.info(
f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Reason: {edge_to_invalidate['reason']}"
)
return invalidated_edges

View File

@ -271,4 +271,4 @@ async def get_relevant_edges(
logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
return relevant_edges.values()
return list(relevant_edges.values())

View File

@ -42,7 +42,7 @@ async def main():
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
messages = parse_podcast_messages()
for i, message in enumerate(messages[3:14]):
for i, message in enumerate(messages[3:50]):
await client.add_episode(
name=f"Message {i}",
episode_body=f"{message.speaker_name} ({message.role}): {message.content}",

3054
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -10,17 +10,11 @@ python = "^3.10"
pydantic = "^2.8.2"
fastapi = "^0.112.0"
neo4j = "^5.23.0"
chromadb = "^0.5.5"
sentence-transformers = "^3.0.1"
diskcache = "^5.6.3"
tiktoken = "^0.7.0"
deepeval = "^0.21.74"
arrow = "^1.3.0"
groq = "^0.9.0"
openai = "^1.38.0"
tqdm = "^4.66.4"
python-dotenv = "^1.0.1"
pandas = "^2.2.2"
pytest-asyncio = "^0.23.8"
pytest-xdist = "^3.6.1"
pytest = "^8.3.2"
@ -29,3 +23,6 @@ pytest = "^8.3.2"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
pythonpath = ["."]

3
pytest.ini Normal file
View File

@ -0,0 +1,3 @@
[pytest]
markers =
integration: marks tests as integration tests

View File

@ -49,19 +49,29 @@ async def main():
)
await client.add_episode(
name="Message 2",
episode_body="Paul: I own many bananas",
episode_body="Paul: I hate apples now",
source_description="WhatsApp Message",
)
await client.add_episode(
name="Message 3",
episode_body="Assistant: The best type of apples available are Fuji apples",
episode_body="Jane: I am married to Paul",
source_description="WhatsApp Message",
)
await client.add_episode(
name="Message 4",
episode_body="Paul: Oh, I actually hate those",
episode_body="Paul: I have divorced Jane",
source_description="WhatsApp Message",
)
# await client.add_episode(
# name="Message 3",
# episode_body="Assistant: The best type of apples available are Fuji apples",
# source_description="WhatsApp Message",
# )
# await client.add_episode(
# name="Message 4",
# episode_body="Paul: Oh, I actually hate those",
# source_description="WhatsApp Message",
# )
asyncio.run(main())

View File

@ -3,6 +3,9 @@ import sys
import os
import pytest
pytestmark = pytest.mark.integration
import asyncio
from dotenv import load_dotenv

View File

@ -0,0 +1,233 @@
import pytest
from datetime import datetime, timedelta
from core.utils.maintenance.temporal_operations import (
prepare_edges_for_invalidation,
prepare_invalidation_context,
)
from core.edges import EntityEdge
from core.nodes import EntityNode
# Helper function to create test data
def create_test_data():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid="1", name="Node1", labels=["Person"], created_at=now)
node2 = EntityNode(uuid="2", name="Node2", labels=["Person"], created_at=now)
node3 = EntityNode(uuid="3", name="Node3", labels=["Person"], created_at=now)
# Create edges
existing_edge1 = EntityEdge(
uuid="e1",
source_node_uuid="1",
target_node_uuid="2",
name="KNOWS",
fact="Node1 knows Node2",
created_at=now,
)
existing_edge2 = EntityEdge(
uuid="e2",
source_node_uuid="2",
target_node_uuid="3",
name="LIKES",
fact="Node2 likes Node3",
created_at=now,
)
new_edge1 = EntityEdge(
uuid="e3",
source_node_uuid="1",
target_node_uuid="3",
name="WORKS_WITH",
fact="Node1 works with Node3",
created_at=now,
)
new_edge2 = EntityEdge(
uuid="e4",
source_node_uuid="1",
target_node_uuid="2",
name="DISLIKES",
fact="Node1 dislikes Node2",
created_at=now,
)
return {
"nodes": [node1, node2, node3],
"existing_edges": [existing_edge1, existing_edge2],
"new_edges": [new_edge1, new_edge2],
}
def test_prepare_edges_for_invalidation_basic():
test_data = create_test_data()
existing_edges_pending_invalidation, new_edges_with_nodes = (
prepare_edges_for_invalidation(
test_data["existing_edges"], test_data["new_edges"], test_data["nodes"]
)
)
assert len(existing_edges_pending_invalidation) == 2
assert len(new_edges_with_nodes) == 2
# Check if the edges are correctly associated with nodes
for edge_with_nodes in existing_edges_pending_invalidation + new_edges_with_nodes:
assert isinstance(edge_with_nodes[0], EntityNode)
assert isinstance(edge_with_nodes[1], EntityEdge)
assert isinstance(edge_with_nodes[2], EntityNode)
def test_prepare_edges_for_invalidation_no_existing_edges():
test_data = create_test_data()
existing_edges_pending_invalidation, new_edges_with_nodes = (
prepare_edges_for_invalidation([], test_data["new_edges"], test_data["nodes"])
)
assert len(existing_edges_pending_invalidation) == 0
assert len(new_edges_with_nodes) == 2
def test_prepare_edges_for_invalidation_no_new_edges():
test_data = create_test_data()
existing_edges_pending_invalidation, new_edges_with_nodes = (
prepare_edges_for_invalidation(
test_data["existing_edges"], [], test_data["nodes"]
)
)
assert len(existing_edges_pending_invalidation) == 2
assert len(new_edges_with_nodes) == 0
def test_prepare_edges_for_invalidation_missing_nodes():
test_data = create_test_data()
# Remove one node to simulate a missing node scenario
nodes = test_data["nodes"][:-1]
existing_edges_pending_invalidation, new_edges_with_nodes = (
prepare_edges_for_invalidation(
test_data["existing_edges"], test_data["new_edges"], nodes
)
)
assert len(existing_edges_pending_invalidation) == 1
assert len(new_edges_with_nodes) == 1
def test_prepare_invalidation_context():
# Create test data
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid="1", name="Node1", labels=["Person"], created_at=now)
node2 = EntityNode(uuid="2", name="Node2", labels=["Person"], created_at=now)
node3 = EntityNode(uuid="3", name="Node3", labels=["Person"], created_at=now)
# Create edges
edge1 = EntityEdge(
uuid="e1",
source_node_uuid="1",
target_node_uuid="2",
name="KNOWS",
fact="Node1 knows Node2",
created_at=now,
)
edge2 = EntityEdge(
uuid="e2",
source_node_uuid="2",
target_node_uuid="3",
name="LIKES",
fact="Node2 likes Node3",
created_at=now,
)
# Create NodeEdgeNodeTriplet objects
existing_edge = (node1, edge1, node2)
new_edge = (node2, edge2, node3)
# Prepare test input
existing_edges = [existing_edge]
new_edges = [new_edge]
# Call the function
result = prepare_invalidation_context(existing_edges, new_edges)
# Assert the result
assert isinstance(result, dict)
assert "existing_edges" in result
assert "new_edges" in result
assert len(result["existing_edges"]) == 1
assert len(result["new_edges"]) == 1
# Check the format of the existing edge
existing_edge_str = result["existing_edges"][0]
assert edge1.uuid in existing_edge_str
assert node1.name in existing_edge_str
assert edge1.name in existing_edge_str
assert node2.name in existing_edge_str
assert edge1.created_at.isoformat() in existing_edge_str
# Check the format of the new edge
new_edge_str = result["new_edges"][0]
assert edge2.uuid in new_edge_str
assert node2.name in new_edge_str
assert edge2.name in new_edge_str
assert node3.name in new_edge_str
assert edge2.created_at.isoformat() in new_edge_str
def test_prepare_invalidation_context_empty_input():
result = prepare_invalidation_context([], [])
assert isinstance(result, dict)
assert "existing_edges" in result
assert "new_edges" in result
assert len(result["existing_edges"]) == 0
assert len(result["new_edges"]) == 0
def test_prepare_invalidation_context_sorting():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid="1", name="Node1", labels=["Person"], created_at=now)
node2 = EntityNode(uuid="2", name="Node2", labels=["Person"], created_at=now)
# Create edges with different timestamps
edge1 = EntityEdge(
uuid="e1",
source_node_uuid="1",
target_node_uuid="2",
name="KNOWS",
fact="Node1 knows Node2",
created_at=now,
)
edge2 = EntityEdge(
uuid="e2",
source_node_uuid="2",
target_node_uuid="1",
name="LIKES",
fact="Node2 likes Node1",
created_at=now + timedelta(hours=1),
)
edge_with_nodes1 = (node1, edge1, node2)
edge_with_nodes2 = (node2, edge2, node1)
# Prepare test input
existing_edges = [edge_with_nodes1, edge_with_nodes2]
# Call the function
result = prepare_invalidation_context(existing_edges, [])
# Assert the result
assert len(result["existing_edges"]) == 2
assert edge2.uuid in result["existing_edges"][0] # The newer edge should be first
assert edge1.uuid in result["existing_edges"][1] # The older edge should be second
# Run the tests
if __name__ == "__main__":
pytest.main([__file__])

View File

@ -0,0 +1,306 @@
import pytest
from datetime import datetime, timedelta
from core.utils.maintenance.temporal_operations import (
invalidate_edges,
)
from core.edges import EntityEdge
from core.nodes import EntityNode
from core.llm_client import OpenAIClient, LLMConfig
from dotenv import load_dotenv
import os
load_dotenv()
def setup_llm_client():
return OpenAIClient(
LLMConfig(
api_key=os.getenv("TEST_OPENAI_API_KEY"),
model=os.getenv("TEST_OPENAI_MODEL"),
base_url="https://api.openai.com/v1",
)
)
# Helper function to create test data
def create_test_data():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid="1", name="Alice", labels=["Person"], created_at=now)
node2 = EntityNode(uuid="2", name="Bob", labels=["Person"], created_at=now)
# Create edges
edge1 = EntityEdge(
uuid="e1",
source_node_uuid="1",
target_node_uuid="2",
name="LIKES",
fact="Alice likes Bob",
created_at=now - timedelta(days=1),
)
edge2 = EntityEdge(
uuid="e2",
source_node_uuid="1",
target_node_uuid="2",
name="DISLIKES",
fact="Alice dislikes Bob",
created_at=now,
)
existing_edge = (node1, edge1, node2)
new_edge = (node1, edge2, node2)
return existing_edge, new_edge
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges():
existing_edge, new_edge = create_test_data()
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge], [new_edge]
)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge[1].uuid
assert invalidated_edges[0].expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_invalidation():
existing_edge, _ = create_test_data()
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [])
assert len(invalidated_edges) == 0
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_multiple_existing():
existing_edge1, new_edge = create_test_data()
existing_edge2, _ = create_test_data()
existing_edge2[1].uuid = "e3"
existing_edge2[1].name = "KNOWS"
existing_edge2[1].fact = "Alice knows Bob"
invalidated_edges = await invalidate_edges(
setup_llm_client(), [existing_edge1, existing_edge2], [new_edge]
)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge1[1].uuid
assert invalidated_edges[0].expired_at is not None
# Helper function to create more complex test data
def create_complex_test_data():
now = datetime.now()
# Create nodes
node1 = EntityNode(uuid="1", name="Alice", labels=["Person"], created_at=now)
node2 = EntityNode(uuid="2", name="Bob", labels=["Person"], created_at=now)
node3 = EntityNode(uuid="3", name="Charlie", labels=["Person"], created_at=now)
node4 = EntityNode(
uuid="4", name="Company XYZ", labels=["Organization"], created_at=now
)
# Create edges
edge1 = EntityEdge(
uuid="e1",
source_node_uuid="1",
target_node_uuid="2",
name="LIKES",
fact="Alice likes Bob",
created_at=now - timedelta(days=5),
)
edge2 = EntityEdge(
uuid="e2",
source_node_uuid="1",
target_node_uuid="3",
name="FRIENDS_WITH",
fact="Alice is friends with Charlie",
created_at=now - timedelta(days=3),
)
edge3 = EntityEdge(
uuid="e3",
source_node_uuid="2",
target_node_uuid="4",
name="WORKS_FOR",
fact="Bob works for Company XYZ",
created_at=now - timedelta(days=2),
)
existing_edge1 = (node1, edge1, node2)
existing_edge2 = (node1, edge2, node3)
existing_edge3 = (node2, edge3, node4)
return [existing_edge1, existing_edge2, existing_edge3], [
node1,
node2,
node3,
node4,
]
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_complex():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that contradicts an existing one
new_edge = (
nodes[0],
EntityEdge(
uuid="e4",
source_node_uuid="1",
target_node_uuid="2",
name="DISLIKES",
fact="Alice dislikes Bob",
created_at=datetime.now(),
),
nodes[1],
)
invalidated_edges = await invalidate_edges(
setup_llm_client(), existing_edges, [new_edge]
)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == "e1"
assert invalidated_edges[0].expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_temporal_update():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that updates an existing one with new information
new_edge = (
nodes[1],
EntityEdge(
uuid="e5",
source_node_uuid="2",
target_node_uuid="4",
name="LEFT_JOB",
fact="Bob left his job at Company XYZ",
created_at=datetime.now(),
),
nodes[3],
)
invalidated_edges = await invalidate_edges(
setup_llm_client(), existing_edges, [new_edge]
)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == "e3"
assert invalidated_edges[0].expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_multiple_invalidations():
existing_edges, nodes = create_complex_test_data()
# Create new edges that invalidate multiple existing edges
new_edge1 = (
nodes[0],
EntityEdge(
uuid="e6",
source_node_uuid="1",
target_node_uuid="2",
name="ENEMIES_WITH",
fact="Alice and Bob are now enemies",
created_at=datetime.now(),
),
nodes[1],
)
new_edge2 = (
nodes[0],
EntityEdge(
uuid="e7",
source_node_uuid="1",
target_node_uuid="3",
name="ENDED_FRIENDSHIP",
fact="Alice ended her friendship with Charlie",
created_at=datetime.now(),
),
nodes[2],
)
invalidated_edges = await invalidate_edges(
setup_llm_client(), existing_edges, [new_edge1, new_edge2]
)
assert len(invalidated_edges) == 2
assert set(edge.uuid for edge in invalidated_edges) == {"e1", "e2"}
for edge in invalidated_edges:
assert edge.expired_at is not None
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_no_effect():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that doesn't invalidate any existing edges
new_edge = (
nodes[2],
EntityEdge(
uuid="e8",
source_node_uuid="3",
target_node_uuid="4",
name="APPLIED_TO",
fact="Charlie applied to Company XYZ",
created_at=datetime.now(),
),
nodes[3],
)
invalidated_edges = await invalidate_edges(
setup_llm_client(), existing_edges, [new_edge]
)
assert len(invalidated_edges) == 0
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_partial_update():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that partially updates an existing one
new_edge = (
nodes[1],
EntityEdge(
uuid="e9",
source_node_uuid="2",
target_node_uuid="4",
name="CHANGED_POSITION",
fact="Bob changed his position at Company XYZ",
created_at=datetime.now(),
),
nodes[3],
)
invalidated_edges = await invalidate_edges(
setup_llm_client(), existing_edges, [new_edge]
)
assert (
len(invalidated_edges) == 0
) # The existing edge is not invalidated, just updated
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_empty_inputs():
invalidated_edges = await invalidate_edges(setup_llm_client(), [], [])
assert len(invalidated_edges) == 0