mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
feat: Initial version of temporal invalidation + tests (#8)
* feat: Initial version of temporal invalidation + tests * fix: dont run int tests on CI * fix: dont run int tests on CI * fix: dont run int tests on CI * fix: time of day issue * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * fix: running non int tests in ci * revert: Tests structural changes * chore: Remove idea file * chore: Get rid of NodesWithEdges class and define a triplet type instead
This commit is contained in:
parent
40e74a2e97
commit
a6fd0ddb75
39
.github/workflows/unit_tests.yml
vendored
Normal file
39
.github/workflows/unit_tests.yml
vendored
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
name: Unit Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
- name: Load cached Poetry installation
|
||||||
|
uses: actions/cache@v3
|
||||||
|
with:
|
||||||
|
path: ~/.local
|
||||||
|
key: poetry-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
|
||||||
|
- name: Install Poetry
|
||||||
|
uses: snok/install-poetry@v1
|
||||||
|
with:
|
||||||
|
virtualenvs-create: true
|
||||||
|
virtualenvs-in-project: true
|
||||||
|
- name: Load cached dependencies
|
||||||
|
uses: actions/cache@v3
|
||||||
|
with:
|
||||||
|
path: .venv
|
||||||
|
key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
|
||||||
|
- name: Install dependencies
|
||||||
|
run: poetry install --no-interaction --no-root
|
||||||
|
- name: Run non-integration tests
|
||||||
|
env:
|
||||||
|
PYTHONPATH: ${{ github.workspace }}
|
||||||
|
run: |
|
||||||
|
poetry run pytest -m "not integration"
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -158,5 +158,5 @@ cython_debug/
|
|||||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
6
conftest.py
Normal file
6
conftest.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# This code adds the project root directory to the Python path, allowing imports to work correctly when running tests.
|
||||||
|
# Without this file, you might encounter ModuleNotFoundError when trying to import modules from your project, especially when running tests.
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__))))
|
@ -8,18 +8,22 @@ import os
|
|||||||
|
|
||||||
from core.llm_client.config import EMBEDDING_DIM
|
from core.llm_client.config import EMBEDDING_DIM
|
||||||
from core.nodes import EntityNode, EpisodicNode, Node
|
from core.nodes import EntityNode, EpisodicNode, Node
|
||||||
from core.edges import EntityEdge, Edge, EpisodicEdge
|
from core.edges import EntityEdge, EpisodicEdge
|
||||||
from core.utils import (
|
from core.utils import (
|
||||||
build_episodic_edges,
|
build_episodic_edges,
|
||||||
retrieve_relevant_schema,
|
|
||||||
extract_new_edges,
|
|
||||||
extract_new_nodes,
|
|
||||||
clear_data,
|
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
|
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
|
||||||
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
|
from core.utils.maintenance.edge_operations import (
|
||||||
|
extract_edges,
|
||||||
|
dedupe_extracted_edges,
|
||||||
|
)
|
||||||
|
|
||||||
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
|
||||||
|
from core.utils.maintenance.temporal_operations import (
|
||||||
|
prepare_edges_for_invalidation,
|
||||||
|
invalidate_edges,
|
||||||
|
)
|
||||||
from core.utils.search.search_utils import (
|
from core.utils.search.search_utils import (
|
||||||
edge_similarity_search,
|
edge_similarity_search,
|
||||||
entity_fulltext_search,
|
entity_fulltext_search,
|
||||||
@ -59,21 +63,6 @@ class Graphiti:
|
|||||||
"""Retrieve the last n episodic nodes from the graph"""
|
"""Retrieve the last n episodic nodes from the graph"""
|
||||||
return await retrieve_episodes(self.driver, last_n, sources)
|
return await retrieve_episodes(self.driver, last_n, sources)
|
||||||
|
|
||||||
async def retrieve_relevant_schema(self, query: str = None) -> dict[str, any]:
|
|
||||||
"""Retrieve relevant nodes and edges to a specific query"""
|
|
||||||
return await retrieve_relevant_schema(self.driver, query)
|
|
||||||
...
|
|
||||||
|
|
||||||
# Invalidate edges that are no longer valid
|
|
||||||
async def invalidate_edges(
|
|
||||||
self,
|
|
||||||
episode: EpisodicNode,
|
|
||||||
new_nodes: list[EntityNode],
|
|
||||||
new_edges: list[EntityEdge],
|
|
||||||
relevant_schema: dict[str, any],
|
|
||||||
previous_episodes: list[EpisodicNode],
|
|
||||||
): ...
|
|
||||||
|
|
||||||
async def add_episode(
|
async def add_episode(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
@ -102,7 +91,6 @@ class Graphiti:
|
|||||||
created_at=now,
|
created_at=now,
|
||||||
valid_at=reference_time,
|
valid_at=reference_time,
|
||||||
)
|
)
|
||||||
# relevant_schema = await self.retrieve_relevant_schema(episode.content)
|
|
||||||
|
|
||||||
extracted_nodes = await extract_nodes(
|
extracted_nodes = await extract_nodes(
|
||||||
self.llm_client, episode, previous_episodes
|
self.llm_client, episode, previous_episodes
|
||||||
@ -139,13 +127,32 @@ class Graphiti:
|
|||||||
f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
|
f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
new_edges = await dedupe_extracted_edges(
|
deduped_edges = await dedupe_extracted_edges(
|
||||||
self.llm_client, extracted_edges, existing_edges
|
self.llm_client, extracted_edges, existing_edges
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")
|
(
|
||||||
|
old_edges_with_nodes_pending_invalidation,
|
||||||
|
new_edges_with_nodes,
|
||||||
|
) = prepare_edges_for_invalidation(
|
||||||
|
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
|
||||||
|
)
|
||||||
|
|
||||||
entity_edges.extend(new_edges)
|
invalidated_edges = await invalidate_edges(
|
||||||
|
self.llm_client,
|
||||||
|
old_edges_with_nodes_pending_invalidation,
|
||||||
|
new_edges_with_nodes,
|
||||||
|
)
|
||||||
|
|
||||||
|
entity_edges.extend(invalidated_edges)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")
|
||||||
|
|
||||||
|
entity_edges.extend(deduped_edges)
|
||||||
episodic_edges.extend(
|
episodic_edges.extend(
|
||||||
build_episodic_edges(
|
build_episodic_edges(
|
||||||
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
|
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
|
||||||
|
@ -71,7 +71,9 @@ class EntityNode(Node):
|
|||||||
name_embedding: list[float] | None = Field(
|
name_embedding: list[float] | None = Field(
|
||||||
default=None, description="embedding of the name"
|
default=None, description="embedding of the name"
|
||||||
)
|
)
|
||||||
summary: str = Field(description="regional summary of surrounding edges")
|
summary: str = Field(
|
||||||
|
description="regional summary of surrounding edges", default_factory=str
|
||||||
|
)
|
||||||
|
|
||||||
async def update_summary(self, driver: AsyncDriver): ...
|
async def update_summary(self, driver: AsyncDriver): ...
|
||||||
|
|
||||||
|
50
core/prompts/invalidate_edges.py
Normal file
50
core/prompts/invalidate_edges.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from typing import Protocol, TypedDict
|
||||||
|
from .models import Message, PromptVersion, PromptFunction
|
||||||
|
|
||||||
|
|
||||||
|
class Prompt(Protocol):
|
||||||
|
v1: PromptVersion
|
||||||
|
|
||||||
|
|
||||||
|
class Versions(TypedDict):
|
||||||
|
v1: PromptFunction
|
||||||
|
|
||||||
|
|
||||||
|
def v1(context: dict[str, any]) -> list[Message]:
|
||||||
|
return [
|
||||||
|
Message(
|
||||||
|
role="system",
|
||||||
|
content="You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.",
|
||||||
|
),
|
||||||
|
Message(
|
||||||
|
role="user",
|
||||||
|
content=f"""
|
||||||
|
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
|
||||||
|
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
|
||||||
|
Do not invalidate relationships merely because they weren't mentioned in new edges.
|
||||||
|
|
||||||
|
Existing Edges (sorted by timestamp, newest first):
|
||||||
|
{context['existing_edges']}
|
||||||
|
|
||||||
|
New Edges:
|
||||||
|
{context['new_edges']}
|
||||||
|
|
||||||
|
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)"
|
||||||
|
|
||||||
|
For each existing edge that should be invalidated, respond with a JSON object in the following format:
|
||||||
|
{{
|
||||||
|
"invalidated_edges": [
|
||||||
|
{{
|
||||||
|
"edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
|
||||||
|
"reason": "Brief explanation of why this edge is being invalidated"
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
|
||||||
|
If no relationships need to be invalidated, return an empty list for "invalidated_edges".
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
versions: Versions = {"v1": v1}
|
@ -26,12 +26,19 @@ from .dedupe_edges import (
|
|||||||
versions as dedupe_edges_versions,
|
versions as dedupe_edges_versions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .invalidate_edges import (
|
||||||
|
Prompt as InvalidateEdgesPrompt,
|
||||||
|
Versions as InvalidateEdgesVersions,
|
||||||
|
versions as invalidate_edges_versions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PromptLibrary(Protocol):
|
class PromptLibrary(Protocol):
|
||||||
extract_nodes: ExtractNodesPrompt
|
extract_nodes: ExtractNodesPrompt
|
||||||
dedupe_nodes: DedupeNodesPrompt
|
dedupe_nodes: DedupeNodesPrompt
|
||||||
extract_edges: ExtractEdgesPrompt
|
extract_edges: ExtractEdgesPrompt
|
||||||
dedupe_edges: DedupeEdgesPrompt
|
dedupe_edges: DedupeEdgesPrompt
|
||||||
|
invalidate_edges: InvalidateEdgesPrompt
|
||||||
|
|
||||||
|
|
||||||
class PromptLibraryImpl(TypedDict):
|
class PromptLibraryImpl(TypedDict):
|
||||||
@ -39,6 +46,7 @@ class PromptLibraryImpl(TypedDict):
|
|||||||
dedupe_nodes: DedupeNodesVersions
|
dedupe_nodes: DedupeNodesVersions
|
||||||
extract_edges: ExtractEdgesVersions
|
extract_edges: ExtractEdgesVersions
|
||||||
dedupe_edges: DedupeEdgesVersions
|
dedupe_edges: DedupeEdgesVersions
|
||||||
|
invalidate_edges: InvalidateEdgesVersions
|
||||||
|
|
||||||
|
|
||||||
class VersionWrapper:
|
class VersionWrapper:
|
||||||
@ -66,6 +74,7 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
|
|||||||
"dedupe_nodes": dedupe_nodes_versions,
|
"dedupe_nodes": dedupe_nodes_versions,
|
||||||
"extract_edges": extract_edges_versions,
|
"extract_edges": extract_edges_versions,
|
||||||
"dedupe_edges": dedupe_edges_versions,
|
"dedupe_edges": dedupe_edges_versions,
|
||||||
|
"invalidate_edges": invalidate_edges_versions,
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
|
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)
|
||||||
|
@ -3,7 +3,6 @@ from .maintenance import (
|
|||||||
build_episodic_edges,
|
build_episodic_edges,
|
||||||
extract_new_nodes,
|
extract_new_nodes,
|
||||||
clear_data,
|
clear_data,
|
||||||
retrieve_relevant_schema,
|
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -12,6 +11,5 @@ __all__ = [
|
|||||||
"build_episodic_edges",
|
"build_episodic_edges",
|
||||||
"extract_new_nodes",
|
"extract_new_nodes",
|
||||||
"clear_data",
|
"clear_data",
|
||||||
"retrieve_relevant_schema",
|
|
||||||
"retrieve_episodes",
|
"retrieve_episodes",
|
||||||
]
|
]
|
||||||
|
@ -2,7 +2,6 @@ from .edge_operations import extract_new_edges, build_episodic_edges
|
|||||||
from .node_operations import extract_new_nodes
|
from .node_operations import extract_new_nodes
|
||||||
from .graph_data_operations import (
|
from .graph_data_operations import (
|
||||||
clear_data,
|
clear_data,
|
||||||
retrieve_relevant_schema,
|
|
||||||
retrieve_episodes,
|
retrieve_episodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -11,6 +10,6 @@ __all__ = [
|
|||||||
"build_episodic_edges",
|
"build_episodic_edges",
|
||||||
"extract_new_nodes",
|
"extract_new_nodes",
|
||||||
"clear_data",
|
"clear_data",
|
||||||
"retrieve_relevant_schema",
|
|
||||||
"retrieve_episodes",
|
"retrieve_episodes",
|
||||||
|
"invalidate_edges",
|
||||||
]
|
]
|
||||||
|
@ -2,6 +2,8 @@ import json
|
|||||||
from typing import List
|
from typing import List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.nodes import EntityNode, EpisodicNode
|
from core.nodes import EntityNode, EpisodicNode
|
||||||
from core.edges import EpisodicEdge, EntityEdge
|
from core.edges import EpisodicEdge, EntityEdge
|
||||||
import logging
|
import logging
|
||||||
|
@ -17,52 +17,6 @@ async def clear_data(driver: AsyncDriver):
|
|||||||
await session.execute_write(delete_all)
|
await session.execute_write(delete_all)
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_relevant_schema(
|
|
||||||
driver: AsyncDriver, query: str = None
|
|
||||||
) -> dict[str, any]:
|
|
||||||
async with driver.session() as session:
|
|
||||||
summary_query = """
|
|
||||||
MATCH (n)
|
|
||||||
OPTIONAL MATCH (n)-[r]->(m)
|
|
||||||
RETURN DISTINCT labels(n) AS node_labels, n.uuid AS node_uuid, n.name AS node_name,
|
|
||||||
type(r) AS relationship_type, r.name AS relationship_name, m.name AS related_node_name
|
|
||||||
"""
|
|
||||||
result = await session.run(summary_query)
|
|
||||||
records = [record async for record in result]
|
|
||||||
|
|
||||||
schema = {"nodes": {}, "relationships": []}
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
node_label = record["node_labels"][0] # Assuming one label per node
|
|
||||||
node_uuid = record["node_uuid"]
|
|
||||||
node_name = record["node_name"]
|
|
||||||
rel_type = record["relationship_type"]
|
|
||||||
rel_name = record["relationship_name"]
|
|
||||||
related_node = record["related_node_name"]
|
|
||||||
|
|
||||||
if node_name not in schema["nodes"]:
|
|
||||||
schema["nodes"][node_name] = {
|
|
||||||
"uuid": node_uuid,
|
|
||||||
"label": node_label,
|
|
||||||
"relationships": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
if rel_type and related_node:
|
|
||||||
schema["nodes"][node_name]["relationships"].append(
|
|
||||||
{"type": rel_type, "name": rel_name, "target": related_node}
|
|
||||||
)
|
|
||||||
schema["relationships"].append(
|
|
||||||
{
|
|
||||||
"source": node_name,
|
|
||||||
"type": rel_type,
|
|
||||||
"name": rel_name,
|
|
||||||
"target": related_node,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return schema
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_episodes(
|
async def retrieve_episodes(
|
||||||
driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
|
driver: AsyncDriver, last_n: int, sources: list[str] | None = "messages"
|
||||||
) -> list[EpisodicNode]:
|
) -> list[EpisodicNode]:
|
||||||
|
100
core/utils/maintenance/temporal_operations.py
Normal file
100
core/utils/maintenance/temporal_operations.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
from core.llm_client import LLMClient
|
||||||
|
from core.edges import EntityEdge
|
||||||
|
from core.nodes import EntityNode
|
||||||
|
from core.prompts import prompt_library
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
NodeEdgeNodeTriplet = tuple[EntityNode, EntityEdge, EntityNode]
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_edges_for_invalidation(
|
||||||
|
existing_edges: list[EntityEdge],
|
||||||
|
new_edges: list[EntityEdge],
|
||||||
|
nodes: list[EntityNode],
|
||||||
|
) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
|
||||||
|
existing_edges_pending_invalidation = []
|
||||||
|
new_edges_with_nodes = []
|
||||||
|
|
||||||
|
existing_edges_pending_invalidation = []
|
||||||
|
new_edges_with_nodes = []
|
||||||
|
|
||||||
|
for edge_list, result_list in [
|
||||||
|
(existing_edges, existing_edges_pending_invalidation),
|
||||||
|
(new_edges, new_edges_with_nodes),
|
||||||
|
]:
|
||||||
|
for edge in edge_list:
|
||||||
|
source_node = next(
|
||||||
|
(node for node in nodes if node.uuid == edge.source_node_uuid), None
|
||||||
|
)
|
||||||
|
target_node = next(
|
||||||
|
(node for node in nodes if node.uuid == edge.target_node_uuid), None
|
||||||
|
)
|
||||||
|
|
||||||
|
if source_node and target_node:
|
||||||
|
result_list.append((source_node, edge, target_node))
|
||||||
|
|
||||||
|
return existing_edges_pending_invalidation, new_edges_with_nodes
|
||||||
|
|
||||||
|
|
||||||
|
async def invalidate_edges(
|
||||||
|
llm_client: LLMClient,
|
||||||
|
existing_edges_pending_invalidation: List[NodeEdgeNodeTriplet],
|
||||||
|
new_edges: List[NodeEdgeNodeTriplet],
|
||||||
|
) -> List[EntityEdge]:
|
||||||
|
invalidated_edges = []
|
||||||
|
|
||||||
|
context = prepare_invalidation_context(
|
||||||
|
existing_edges_pending_invalidation, new_edges
|
||||||
|
)
|
||||||
|
llm_response = await llm_client.generate_response(
|
||||||
|
prompt_library.invalidate_edges.v1(context)
|
||||||
|
)
|
||||||
|
|
||||||
|
edges_to_invalidate = llm_response.get("invalidated_edges", [])
|
||||||
|
invalidated_edges = process_edge_invalidation_llm_response(
|
||||||
|
edges_to_invalidate, existing_edges_pending_invalidation
|
||||||
|
)
|
||||||
|
|
||||||
|
return invalidated_edges
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_invalidation_context(
|
||||||
|
existing_edges: List[NodeEdgeNodeTriplet], new_edges: List[NodeEdgeNodeTriplet]
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"existing_edges": [
|
||||||
|
f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})"
|
||||||
|
for source_node, edge, target_node in sorted(
|
||||||
|
existing_edges, key=lambda x: x[1].created_at, reverse=True
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"new_edges": [
|
||||||
|
f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} ({edge.created_at.isoformat()})"
|
||||||
|
for source_node, edge, target_node in sorted(
|
||||||
|
new_edges, key=lambda x: x[1].created_at, reverse=True
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def process_edge_invalidation_llm_response(
|
||||||
|
edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
|
||||||
|
) -> List[EntityEdge]:
|
||||||
|
invalidated_edges = []
|
||||||
|
for edge_to_invalidate in edges_to_invalidate:
|
||||||
|
edge_uuid = edge_to_invalidate["edge_uuid"]
|
||||||
|
edge_to_update = next(
|
||||||
|
(edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if edge_to_update:
|
||||||
|
edge_to_update.expired_at = datetime.now()
|
||||||
|
invalidated_edges.append(edge_to_update)
|
||||||
|
logger.info(
|
||||||
|
f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Reason: {edge_to_invalidate['reason']}"
|
||||||
|
)
|
||||||
|
return invalidated_edges
|
@ -271,4 +271,4 @@ async def get_relevant_edges(
|
|||||||
|
|
||||||
logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
|
logger.info(f"Found relevant nodes: {relevant_edges.keys()}")
|
||||||
|
|
||||||
return relevant_edges.values()
|
return list(relevant_edges.values())
|
||||||
|
@ -42,7 +42,7 @@ async def main():
|
|||||||
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||||
await clear_data(client.driver)
|
await clear_data(client.driver)
|
||||||
messages = parse_podcast_messages()
|
messages = parse_podcast_messages()
|
||||||
for i, message in enumerate(messages[3:14]):
|
for i, message in enumerate(messages[3:50]):
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name=f"Message {i}",
|
name=f"Message {i}",
|
||||||
episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
|
episode_body=f"{message.speaker_name} ({message.role}): {message.content}",
|
||||||
|
3054
poetry.lock
generated
3054
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -10,17 +10,11 @@ python = "^3.10"
|
|||||||
pydantic = "^2.8.2"
|
pydantic = "^2.8.2"
|
||||||
fastapi = "^0.112.0"
|
fastapi = "^0.112.0"
|
||||||
neo4j = "^5.23.0"
|
neo4j = "^5.23.0"
|
||||||
chromadb = "^0.5.5"
|
|
||||||
sentence-transformers = "^3.0.1"
|
sentence-transformers = "^3.0.1"
|
||||||
diskcache = "^5.6.3"
|
diskcache = "^5.6.3"
|
||||||
tiktoken = "^0.7.0"
|
|
||||||
deepeval = "^0.21.74"
|
|
||||||
arrow = "^1.3.0"
|
arrow = "^1.3.0"
|
||||||
groq = "^0.9.0"
|
|
||||||
openai = "^1.38.0"
|
openai = "^1.38.0"
|
||||||
tqdm = "^4.66.4"
|
|
||||||
python-dotenv = "^1.0.1"
|
python-dotenv = "^1.0.1"
|
||||||
pandas = "^2.2.2"
|
|
||||||
pytest-asyncio = "^0.23.8"
|
pytest-asyncio = "^0.23.8"
|
||||||
pytest-xdist = "^3.6.1"
|
pytest-xdist = "^3.6.1"
|
||||||
pytest = "^8.3.2"
|
pytest = "^8.3.2"
|
||||||
@ -29,3 +23,6 @@ pytest = "^8.3.2"
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
pythonpath = ["."]
|
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
markers =
|
||||||
|
integration: marks tests as integration tests
|
16
runner.py
16
runner.py
@ -49,19 +49,29 @@ async def main():
|
|||||||
)
|
)
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name="Message 2",
|
name="Message 2",
|
||||||
episode_body="Paul: I own many bananas",
|
episode_body="Paul: I hate apples now",
|
||||||
source_description="WhatsApp Message",
|
source_description="WhatsApp Message",
|
||||||
)
|
)
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name="Message 3",
|
name="Message 3",
|
||||||
episode_body="Assistant: The best type of apples available are Fuji apples",
|
episode_body="Jane: I am married to Paul",
|
||||||
source_description="WhatsApp Message",
|
source_description="WhatsApp Message",
|
||||||
)
|
)
|
||||||
await client.add_episode(
|
await client.add_episode(
|
||||||
name="Message 4",
|
name="Message 4",
|
||||||
episode_body="Paul: Oh, I actually hate those",
|
episode_body="Paul: I have divorced Jane",
|
||||||
source_description="WhatsApp Message",
|
source_description="WhatsApp Message",
|
||||||
)
|
)
|
||||||
|
# await client.add_episode(
|
||||||
|
# name="Message 3",
|
||||||
|
# episode_body="Assistant: The best type of apples available are Fuji apples",
|
||||||
|
# source_description="WhatsApp Message",
|
||||||
|
# )
|
||||||
|
# await client.add_episode(
|
||||||
|
# name="Message 4",
|
||||||
|
# episode_body="Paul: Oh, I actually hate those",
|
||||||
|
# source_description="WhatsApp Message",
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
@ -3,6 +3,9 @@ import sys
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
233
tests/utils/maintenance/test_temporal_operations.py
Normal file
233
tests/utils/maintenance/test_temporal_operations.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from core.utils.maintenance.temporal_operations import (
|
||||||
|
prepare_edges_for_invalidation,
|
||||||
|
prepare_invalidation_context,
|
||||||
|
)
|
||||||
|
from core.edges import EntityEdge
|
||||||
|
from core.nodes import EntityNode
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function to create test data
|
||||||
|
def create_test_data():
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
# Create nodes
|
||||||
|
node1 = EntityNode(uuid="1", name="Node1", labels=["Person"], created_at=now)
|
||||||
|
node2 = EntityNode(uuid="2", name="Node2", labels=["Person"], created_at=now)
|
||||||
|
node3 = EntityNode(uuid="3", name="Node3", labels=["Person"], created_at=now)
|
||||||
|
|
||||||
|
# Create edges
|
||||||
|
existing_edge1 = EntityEdge(
|
||||||
|
uuid="e1",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="KNOWS",
|
||||||
|
fact="Node1 knows Node2",
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
existing_edge2 = EntityEdge(
|
||||||
|
uuid="e2",
|
||||||
|
source_node_uuid="2",
|
||||||
|
target_node_uuid="3",
|
||||||
|
name="LIKES",
|
||||||
|
fact="Node2 likes Node3",
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
new_edge1 = EntityEdge(
|
||||||
|
uuid="e3",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="3",
|
||||||
|
name="WORKS_WITH",
|
||||||
|
fact="Node1 works with Node3",
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
new_edge2 = EntityEdge(
|
||||||
|
uuid="e4",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="DISLIKES",
|
||||||
|
fact="Node1 dislikes Node2",
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"nodes": [node1, node2, node3],
|
||||||
|
"existing_edges": [existing_edge1, existing_edge2],
|
||||||
|
"new_edges": [new_edge1, new_edge2],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_edges_for_invalidation_basic():
|
||||||
|
test_data = create_test_data()
|
||||||
|
|
||||||
|
existing_edges_pending_invalidation, new_edges_with_nodes = (
|
||||||
|
prepare_edges_for_invalidation(
|
||||||
|
test_data["existing_edges"], test_data["new_edges"], test_data["nodes"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(existing_edges_pending_invalidation) == 2
|
||||||
|
assert len(new_edges_with_nodes) == 2
|
||||||
|
|
||||||
|
# Check if the edges are correctly associated with nodes
|
||||||
|
for edge_with_nodes in existing_edges_pending_invalidation + new_edges_with_nodes:
|
||||||
|
assert isinstance(edge_with_nodes[0], EntityNode)
|
||||||
|
assert isinstance(edge_with_nodes[1], EntityEdge)
|
||||||
|
assert isinstance(edge_with_nodes[2], EntityNode)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_edges_for_invalidation_no_existing_edges():
|
||||||
|
test_data = create_test_data()
|
||||||
|
|
||||||
|
existing_edges_pending_invalidation, new_edges_with_nodes = (
|
||||||
|
prepare_edges_for_invalidation([], test_data["new_edges"], test_data["nodes"])
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(existing_edges_pending_invalidation) == 0
|
||||||
|
assert len(new_edges_with_nodes) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_edges_for_invalidation_no_new_edges():
|
||||||
|
test_data = create_test_data()
|
||||||
|
|
||||||
|
existing_edges_pending_invalidation, new_edges_with_nodes = (
|
||||||
|
prepare_edges_for_invalidation(
|
||||||
|
test_data["existing_edges"], [], test_data["nodes"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(existing_edges_pending_invalidation) == 2
|
||||||
|
assert len(new_edges_with_nodes) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_edges_for_invalidation_missing_nodes():
|
||||||
|
test_data = create_test_data()
|
||||||
|
|
||||||
|
# Remove one node to simulate a missing node scenario
|
||||||
|
nodes = test_data["nodes"][:-1]
|
||||||
|
|
||||||
|
existing_edges_pending_invalidation, new_edges_with_nodes = (
|
||||||
|
prepare_edges_for_invalidation(
|
||||||
|
test_data["existing_edges"], test_data["new_edges"], nodes
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(existing_edges_pending_invalidation) == 1
|
||||||
|
assert len(new_edges_with_nodes) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_invalidation_context():
|
||||||
|
# Create test data
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
# Create nodes
|
||||||
|
node1 = EntityNode(uuid="1", name="Node1", labels=["Person"], created_at=now)
|
||||||
|
node2 = EntityNode(uuid="2", name="Node2", labels=["Person"], created_at=now)
|
||||||
|
node3 = EntityNode(uuid="3", name="Node3", labels=["Person"], created_at=now)
|
||||||
|
|
||||||
|
# Create edges
|
||||||
|
edge1 = EntityEdge(
|
||||||
|
uuid="e1",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="KNOWS",
|
||||||
|
fact="Node1 knows Node2",
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
edge2 = EntityEdge(
|
||||||
|
uuid="e2",
|
||||||
|
source_node_uuid="2",
|
||||||
|
target_node_uuid="3",
|
||||||
|
name="LIKES",
|
||||||
|
fact="Node2 likes Node3",
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create NodeEdgeNodeTriplet objects
|
||||||
|
existing_edge = (node1, edge1, node2)
|
||||||
|
new_edge = (node2, edge2, node3)
|
||||||
|
|
||||||
|
# Prepare test input
|
||||||
|
existing_edges = [existing_edge]
|
||||||
|
new_edges = [new_edge]
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
result = prepare_invalidation_context(existing_edges, new_edges)
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "existing_edges" in result
|
||||||
|
assert "new_edges" in result
|
||||||
|
assert len(result["existing_edges"]) == 1
|
||||||
|
assert len(result["new_edges"]) == 1
|
||||||
|
|
||||||
|
# Check the format of the existing edge
|
||||||
|
existing_edge_str = result["existing_edges"][0]
|
||||||
|
assert edge1.uuid in existing_edge_str
|
||||||
|
assert node1.name in existing_edge_str
|
||||||
|
assert edge1.name in existing_edge_str
|
||||||
|
assert node2.name in existing_edge_str
|
||||||
|
assert edge1.created_at.isoformat() in existing_edge_str
|
||||||
|
|
||||||
|
# Check the format of the new edge
|
||||||
|
new_edge_str = result["new_edges"][0]
|
||||||
|
assert edge2.uuid in new_edge_str
|
||||||
|
assert node2.name in new_edge_str
|
||||||
|
assert edge2.name in new_edge_str
|
||||||
|
assert node3.name in new_edge_str
|
||||||
|
assert edge2.created_at.isoformat() in new_edge_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_invalidation_context_empty_input():
|
||||||
|
result = prepare_invalidation_context([], [])
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "existing_edges" in result
|
||||||
|
assert "new_edges" in result
|
||||||
|
assert len(result["existing_edges"]) == 0
|
||||||
|
assert len(result["new_edges"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_invalidation_context_sorting():
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
# Create nodes
|
||||||
|
node1 = EntityNode(uuid="1", name="Node1", labels=["Person"], created_at=now)
|
||||||
|
node2 = EntityNode(uuid="2", name="Node2", labels=["Person"], created_at=now)
|
||||||
|
|
||||||
|
# Create edges with different timestamps
|
||||||
|
edge1 = EntityEdge(
|
||||||
|
uuid="e1",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="KNOWS",
|
||||||
|
fact="Node1 knows Node2",
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
edge2 = EntityEdge(
|
||||||
|
uuid="e2",
|
||||||
|
source_node_uuid="2",
|
||||||
|
target_node_uuid="1",
|
||||||
|
name="LIKES",
|
||||||
|
fact="Node2 likes Node1",
|
||||||
|
created_at=now + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
edge_with_nodes1 = (node1, edge1, node2)
|
||||||
|
edge_with_nodes2 = (node2, edge2, node1)
|
||||||
|
|
||||||
|
# Prepare test input
|
||||||
|
existing_edges = [edge_with_nodes1, edge_with_nodes2]
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
result = prepare_invalidation_context(existing_edges, [])
|
||||||
|
|
||||||
|
# Assert the result
|
||||||
|
assert len(result["existing_edges"]) == 2
|
||||||
|
assert edge2.uuid in result["existing_edges"][0] # The newer edge should be first
|
||||||
|
assert edge1.uuid in result["existing_edges"][1] # The older edge should be second
|
||||||
|
|
||||||
|
|
||||||
|
# Run the tests
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
306
tests/utils/maintenance/test_temporal_operations_int.py
Normal file
306
tests/utils/maintenance/test_temporal_operations_int.py
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from core.utils.maintenance.temporal_operations import (
|
||||||
|
invalidate_edges,
|
||||||
|
)
|
||||||
|
from core.edges import EntityEdge
|
||||||
|
from core.nodes import EntityNode
|
||||||
|
from core.llm_client import OpenAIClient, LLMConfig
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_llm_client():
|
||||||
|
return OpenAIClient(
|
||||||
|
LLMConfig(
|
||||||
|
api_key=os.getenv("TEST_OPENAI_API_KEY"),
|
||||||
|
model=os.getenv("TEST_OPENAI_MODEL"),
|
||||||
|
base_url="https://api.openai.com/v1",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function to create test data
|
||||||
|
def create_test_data():
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
# Create nodes
|
||||||
|
node1 = EntityNode(uuid="1", name="Alice", labels=["Person"], created_at=now)
|
||||||
|
node2 = EntityNode(uuid="2", name="Bob", labels=["Person"], created_at=now)
|
||||||
|
|
||||||
|
# Create edges
|
||||||
|
edge1 = EntityEdge(
|
||||||
|
uuid="e1",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="LIKES",
|
||||||
|
fact="Alice likes Bob",
|
||||||
|
created_at=now - timedelta(days=1),
|
||||||
|
)
|
||||||
|
edge2 = EntityEdge(
|
||||||
|
uuid="e2",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="DISLIKES",
|
||||||
|
fact="Alice dislikes Bob",
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_edge = (node1, edge1, node2)
|
||||||
|
new_edge = (node1, edge2, node2)
|
||||||
|
|
||||||
|
return existing_edge, new_edge
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges():
|
||||||
|
existing_edge, new_edge = create_test_data()
|
||||||
|
|
||||||
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), [existing_edge], [new_edge]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(invalidated_edges) == 1
|
||||||
|
assert invalidated_edges[0].uuid == existing_edge[1].uuid
|
||||||
|
assert invalidated_edges[0].expired_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges_no_invalidation():
|
||||||
|
existing_edge, _ = create_test_data()
|
||||||
|
|
||||||
|
invalidated_edges = await invalidate_edges(setup_llm_client(), [existing_edge], [])
|
||||||
|
|
||||||
|
assert len(invalidated_edges) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges_multiple_existing():
|
||||||
|
existing_edge1, new_edge = create_test_data()
|
||||||
|
existing_edge2, _ = create_test_data()
|
||||||
|
existing_edge2[1].uuid = "e3"
|
||||||
|
existing_edge2[1].name = "KNOWS"
|
||||||
|
existing_edge2[1].fact = "Alice knows Bob"
|
||||||
|
|
||||||
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), [existing_edge1, existing_edge2], [new_edge]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(invalidated_edges) == 1
|
||||||
|
assert invalidated_edges[0].uuid == existing_edge1[1].uuid
|
||||||
|
assert invalidated_edges[0].expired_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function to create more complex test data
|
||||||
|
def create_complex_test_data():
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
# Create nodes
|
||||||
|
node1 = EntityNode(uuid="1", name="Alice", labels=["Person"], created_at=now)
|
||||||
|
node2 = EntityNode(uuid="2", name="Bob", labels=["Person"], created_at=now)
|
||||||
|
node3 = EntityNode(uuid="3", name="Charlie", labels=["Person"], created_at=now)
|
||||||
|
node4 = EntityNode(
|
||||||
|
uuid="4", name="Company XYZ", labels=["Organization"], created_at=now
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create edges
|
||||||
|
edge1 = EntityEdge(
|
||||||
|
uuid="e1",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="LIKES",
|
||||||
|
fact="Alice likes Bob",
|
||||||
|
created_at=now - timedelta(days=5),
|
||||||
|
)
|
||||||
|
edge2 = EntityEdge(
|
||||||
|
uuid="e2",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="3",
|
||||||
|
name="FRIENDS_WITH",
|
||||||
|
fact="Alice is friends with Charlie",
|
||||||
|
created_at=now - timedelta(days=3),
|
||||||
|
)
|
||||||
|
edge3 = EntityEdge(
|
||||||
|
uuid="e3",
|
||||||
|
source_node_uuid="2",
|
||||||
|
target_node_uuid="4",
|
||||||
|
name="WORKS_FOR",
|
||||||
|
fact="Bob works for Company XYZ",
|
||||||
|
created_at=now - timedelta(days=2),
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_edge1 = (node1, edge1, node2)
|
||||||
|
existing_edge2 = (node1, edge2, node3)
|
||||||
|
existing_edge3 = (node2, edge3, node4)
|
||||||
|
|
||||||
|
return [existing_edge1, existing_edge2, existing_edge3], [
|
||||||
|
node1,
|
||||||
|
node2,
|
||||||
|
node3,
|
||||||
|
node4,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges_complex():
|
||||||
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
|
# Create a new edge that contradicts an existing one
|
||||||
|
new_edge = (
|
||||||
|
nodes[0],
|
||||||
|
EntityEdge(
|
||||||
|
uuid="e4",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="DISLIKES",
|
||||||
|
fact="Alice dislikes Bob",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
),
|
||||||
|
nodes[1],
|
||||||
|
)
|
||||||
|
|
||||||
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), existing_edges, [new_edge]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(invalidated_edges) == 1
|
||||||
|
assert invalidated_edges[0].uuid == "e1"
|
||||||
|
assert invalidated_edges[0].expired_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges_temporal_update():
|
||||||
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
|
# Create a new edge that updates an existing one with new information
|
||||||
|
new_edge = (
|
||||||
|
nodes[1],
|
||||||
|
EntityEdge(
|
||||||
|
uuid="e5",
|
||||||
|
source_node_uuid="2",
|
||||||
|
target_node_uuid="4",
|
||||||
|
name="LEFT_JOB",
|
||||||
|
fact="Bob left his job at Company XYZ",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
),
|
||||||
|
nodes[3],
|
||||||
|
)
|
||||||
|
|
||||||
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), existing_edges, [new_edge]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(invalidated_edges) == 1
|
||||||
|
assert invalidated_edges[0].uuid == "e3"
|
||||||
|
assert invalidated_edges[0].expired_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges_multiple_invalidations():
|
||||||
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
|
# Create new edges that invalidate multiple existing edges
|
||||||
|
new_edge1 = (
|
||||||
|
nodes[0],
|
||||||
|
EntityEdge(
|
||||||
|
uuid="e6",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="2",
|
||||||
|
name="ENEMIES_WITH",
|
||||||
|
fact="Alice and Bob are now enemies",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
),
|
||||||
|
nodes[1],
|
||||||
|
)
|
||||||
|
new_edge2 = (
|
||||||
|
nodes[0],
|
||||||
|
EntityEdge(
|
||||||
|
uuid="e7",
|
||||||
|
source_node_uuid="1",
|
||||||
|
target_node_uuid="3",
|
||||||
|
name="ENDED_FRIENDSHIP",
|
||||||
|
fact="Alice ended her friendship with Charlie",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
),
|
||||||
|
nodes[2],
|
||||||
|
)
|
||||||
|
|
||||||
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), existing_edges, [new_edge1, new_edge2]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(invalidated_edges) == 2
|
||||||
|
assert set(edge.uuid for edge in invalidated_edges) == {"e1", "e2"}
|
||||||
|
for edge in invalidated_edges:
|
||||||
|
assert edge.expired_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges_no_effect():
|
||||||
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
|
# Create a new edge that doesn't invalidate any existing edges
|
||||||
|
new_edge = (
|
||||||
|
nodes[2],
|
||||||
|
EntityEdge(
|
||||||
|
uuid="e8",
|
||||||
|
source_node_uuid="3",
|
||||||
|
target_node_uuid="4",
|
||||||
|
name="APPLIED_TO",
|
||||||
|
fact="Charlie applied to Company XYZ",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
),
|
||||||
|
nodes[3],
|
||||||
|
)
|
||||||
|
|
||||||
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), existing_edges, [new_edge]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(invalidated_edges) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges_partial_update():
|
||||||
|
existing_edges, nodes = create_complex_test_data()
|
||||||
|
|
||||||
|
# Create a new edge that partially updates an existing one
|
||||||
|
new_edge = (
|
||||||
|
nodes[1],
|
||||||
|
EntityEdge(
|
||||||
|
uuid="e9",
|
||||||
|
source_node_uuid="2",
|
||||||
|
target_node_uuid="4",
|
||||||
|
name="CHANGED_POSITION",
|
||||||
|
fact="Bob changed his position at Company XYZ",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
),
|
||||||
|
nodes[3],
|
||||||
|
)
|
||||||
|
|
||||||
|
invalidated_edges = await invalidate_edges(
|
||||||
|
setup_llm_client(), existing_edges, [new_edge]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(invalidated_edges) == 0
|
||||||
|
) # The existing edge is not invalidated, just updated
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.integration
|
||||||
|
async def test_invalidate_edges_empty_inputs():
|
||||||
|
invalidated_edges = await invalidate_edges(setup_llm_client(), [], [])
|
||||||
|
|
||||||
|
assert len(invalidated_edges) == 0
|
Loading…
x
Reference in New Issue
Block a user