Mirror of https://github.com/getzep/graphiti.git, synced 2025-12-27 07:03:47 +00:00
add_fact endpoint (#207)
* add_fact endpoint
* bump version
* add edge invalidation
* update
parent 6536401c8c
commit 3199e893ed
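Most of the hunks below make the same mechanical change: naive datetime.now() calls are replaced with timezone-aware datetime.now(timezone.utc). As a minimal illustration of the difference (not part of the commit), mixing the two kinds of values is an error in Python:

from datetime import datetime, timezone

naive = datetime.now()               # local wall-clock time, tzinfo is None
aware = datetime.now(timezone.utc)   # UTC time, tzinfo is timezone.utc

print(naive.tzinfo, aware.tzinfo)    # None UTC

try:
    naive < aware                    # comparing naive and aware datetimes
except TypeError as err:
    print(err)                       # can't compare offset-naive and offset-aware datetimes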
@@ -19,7 +19,7 @@ import json
 import logging
 import os
 import sys
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 
 from dotenv import load_dotenv
@@ -78,7 +78,7 @@ async def add_messages(client: Graphiti):
             name=f'Message {i}',
             episode_body=message,
             source=EpisodeType.message,
-            reference_time=datetime.now(),
+            reference_time=datetime.now(timezone.utc),
             source_description='Shoe conversation',
         )
 
@@ -105,7 +105,7 @@ async def ingest_products_data(client: Graphiti):
             content=str(product),
             source_description='Allbirds products',
             source=EpisodeType.json,
-            reference_time=datetime.now(),
+            reference_time=datetime.now(timezone.utc),
         )
         for i, product in enumerate(products)
     ]
@@ -15,7 +15,7 @@ limitations under the License.
 """
 
 import json
-from datetime import datetime
+from datetime import datetime, timezone
 
 from pydantic import BaseModel
 
@@ -45,7 +45,7 @@ def parse_msc_messages() -> list[list[ParsedMscMessage]]:
                 ParsedMscMessage(
                     speaker_name=speakers[speaker_idx],
                     content=content,
-                    actual_timestamp=datetime.now(),
+                    actual_timestamp=datetime.now(timezone.utc),
                     group_id=str(i),
                 )
             )
@@ -60,7 +60,7 @@ def parse_msc_messages() -> list[list[ParsedMscMessage]]:
                 ParsedMscMessage(
                     speaker_name=speakers[speaker_idx],
                     content=content,
-                    actual_timestamp=datetime.now(),
+                    actual_timestamp=datetime.now(timezone.utc),
                     group_id=str(i),
                 )
             )
@@ -1,6 +1,6 @@
 import os
 import re
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from typing import List
 
 from pydantic import BaseModel
@@ -61,7 +61,7 @@ def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[Par
             break
 
     # Calculate the start time
-    now = datetime.now()
+    now = datetime.now(timezone.utc)
     podcast_start_time = now - last_timestamp
 
     for message in messages:
@@ -18,7 +18,7 @@ import asyncio
 import logging
 import os
 import sys
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 
 from dotenv import load_dotenv
 
@@ -63,7 +63,7 @@ async def main():
     messages = get_wizard_of_oz_messages()
     print(messages)
     print(len(messages))
-    now = datetime.now()
+    now = datetime.now(timezone.utc)
     # episodes: list[BulkEpisode] = [
     #     BulkEpisode(
    #         name=f'Chapter {i + 1}',
@@ -16,7 +16,7 @@ limitations under the License.
 
 import asyncio
 import logging
-from datetime import datetime
+from datetime import datetime, timezone
 from time import time
 
 from dotenv import load_dotenv
@@ -65,7 +65,9 @@ from graphiti_core.utils.maintenance.community_operations import (
     update_community,
 )
 from graphiti_core.utils.maintenance.edge_operations import (
+    dedupe_extracted_edge,
     extract_edges,
+    resolve_edge_contradictions,
     resolve_extracted_edges,
 )
 from graphiti_core.utils.maintenance.graph_data_operations import (
@@ -76,6 +78,7 @@ from graphiti_core.utils.maintenance.node_operations import (
     extract_nodes,
     resolve_extracted_nodes,
 )
+from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
 
 logger = logging.getLogger(__name__)
 
@@ -312,7 +315,7 @@ class Graphiti:
             start = time()
 
             entity_edges: list[EntityEdge] = []
-            now = datetime.now()
+            now = datetime.now(timezone.utc)
 
             previous_episodes = await self.retrieve_episodes(
                 reference_time, last_n=3, group_ids=[group_id]
@@ -448,7 +451,6 @@ class Graphiti:
 
             episode.entity_edges = [edge.uuid for edge in entity_edges]
 
-            # Future optimization would be using batch operations to save nodes and edges
             if not self.store_raw_episode_content:
                 episode.content = ''
 
@@ -511,7 +513,7 @@ class Graphiti:
         """
         try:
             start = time()
-            now = datetime.now()
+            now = datetime.now(timezone.utc)
 
             episodes = [
                 EpisodicNode(
@@ -760,3 +762,36 @@ class Graphiti:
         communities = await get_communities_by_nodes(self.driver, nodes)
 
         return SearchResults(edges=edges, nodes=nodes, communities=communities)
+
+    async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
+        if source_node.name_embedding is None:
+            await source_node.generate_name_embedding(self.embedder)
+        if target_node.name_embedding is None:
+            await target_node.generate_name_embedding(self.embedder)
+        if edge.fact_embedding is None:
+            await edge.generate_embedding(self.embedder)
+
+        resolved_nodes, _ = await resolve_extracted_nodes(
+            self.llm_client,
+            [source_node, target_node],
+            [
+                await get_relevant_nodes([source_node], self.driver),
+                await get_relevant_nodes([target_node], self.driver),
+            ],
+        )
+
+        related_edges = await get_relevant_edges(
+            self.driver,
+            [edge],
+            source_node_uuid=resolved_nodes[0].uuid,
+            target_node_uuid=resolved_nodes[1].uuid,
+        )
+
+        resolved_edge = await dedupe_extracted_edge(self.llm_client, edge, related_edges)
+
+        contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
+        invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)
+
+        await add_nodes_and_edges_bulk(
+            self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
+        )
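The add_triplet method above is the library-side entry point that the commit's add_fact endpoint presumably builds on: it takes a pre-built source node, edge, and target node, fills in any missing embeddings, dedupes against what is already in the graph, expires contradicted edges, and saves the result. A rough usage sketch, based on the integration test further down; the connection details are placeholders and the import paths are assumptions:

from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode


async def add_fact_example():
    # Placeholder Neo4j credentials; point these at your own instance.
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    now = datetime.now(timezone.utc)

    alice = EntityNode(name='Alice', labels=[], created_at=now, summary='Alice summary', group_id='test')
    bob = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary', group_id='test')
    likes = EntityEdge(
        source_node_uuid=alice.uuid,
        target_node_uuid=bob.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        valid_at=now,
        group_id='test',
    )

    # Embeds, dedupes, invalidates contradicted edges, and persists in one call.
    await graphiti.add_triplet(alice, likes, bob)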
@@ -1,3 +1,19 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 EPISODIC_EDGE_SAVE = """
 MATCH (episode:Episodic {uuid: $episode_uuid})
 MATCH (node:Entity {uuid: $entity_uuid})
@@ -1,3 +1,19 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 EPISODIC_NODE_SAVE = """
 MERGE (n:Episodic {uuid: $uuid})
 SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
@@ -16,7 +16,7 @@ limitations under the License.
 
 import logging
 from abc import ABC, abstractmethod
-from datetime import datetime
+from datetime import datetime, timezone
 from enum import Enum
 from time import time
 from typing import Any
@@ -78,7 +78,7 @@ class Node(BaseModel, ABC):
     name: str = Field(description='name of the node')
     group_id: str = Field(description='partition of the graph')
     labels: list[str] = Field(default_factory=list)
-    created_at: datetime = Field(default_factory=lambda: datetime.now())
+    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
 
     @abstractmethod
     async def save(self, driver: AsyncDriver): ...
@@ -18,7 +18,7 @@ import asyncio
 import logging
 import typing
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime, timezone
 from math import ceil
 
 from neo4j import AsyncDriver, AsyncManagedTransaction
@@ -385,7 +385,7 @@ async def extract_edge_dates_bulk(
         edge.valid_at = valid_at
         edge.invalid_at = invalid_at
         if edge.invalid_at:
-            edge.expired_at = datetime.now()
+            edge.expired_at = datetime.now(timezone.utc)
 
     return edges
 
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime, timezone
 
 from neo4j import AsyncDriver
 from pydantic import BaseModel
@@ -178,7 +178,7 @@ async def build_community(
 
     summary = summaries[0]
     name = await generate_summary_description(llm_client, summary)
-    now = datetime.now()
+    now = datetime.now(timezone.utc)
     community_node = CommunityNode(
         name=name,
         group_id=community_cluster[0].group_id,
@@ -16,7 +16,7 @@ limitations under the License.
 
 import asyncio
 import logging
-from datetime import datetime
+from datetime import datetime, timezone
 from time import time
 from typing import List
 
@@ -110,7 +110,7 @@ async def extract_edges(
             group_id=group_id,
             fact=edge_data['fact'],
             episodes=[episode.uuid],
-            created_at=datetime.now(),
+            created_at=datetime.now(timezone.utc),
             valid_at=None,
             invalid_at=None,
         )
@@ -205,39 +205,9 @@ async def resolve_extracted_edges(
     return resolved_edges, invalidated_edges
 
 
-async def resolve_extracted_edge(
-    llm_client: LLMClient,
-    extracted_edge: EntityEdge,
-    related_edges: list[EntityEdge],
-    existing_edges: list[EntityEdge],
-    current_episode: EpisodicNode,
-    previous_episodes: list[EpisodicNode],
-) -> tuple[EntityEdge, list[EntityEdge]]:
-    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
-        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
-        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
-        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
-    )
-
-    now = datetime.now()
-
-    resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
-    resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
-    if invalid_at is not None and resolved_edge.expired_at is None:
-        resolved_edge.expired_at = now
-
-    # Determine if the new_edge needs to be expired
-    if resolved_edge.expired_at is None:
-        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
-        for candidate in invalidation_candidates:
-            if (
-                candidate.valid_at is not None and resolved_edge.valid_at is not None
-            ) and candidate.valid_at > resolved_edge.valid_at:
-                # Expire new edge since we have information about more recent events
-                resolved_edge.invalid_at = candidate.valid_at
-                resolved_edge.expired_at = now
-                break
-
+def resolve_edge_contradictions(
+    resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
+) -> list[EntityEdge]:
     # Determine which contradictory edges need to be expired
     invalidated_edges: list[EntityEdge] = []
     for edge in invalidation_candidates:
@@ -259,9 +229,50 @@ async def resolve_extracted_edge(
             and edge.valid_at < resolved_edge.valid_at
         ):
             edge.invalid_at = resolved_edge.valid_at
-            edge.expired_at = edge.expired_at if edge.expired_at is not None else now
+            edge.expired_at = (
+                edge.expired_at if edge.expired_at is not None else datetime.now(timezone.utc)
+            )
             invalidated_edges.append(edge)
 
     return invalidated_edges
+
+
+async def resolve_extracted_edge(
+    llm_client: LLMClient,
+    extracted_edge: EntityEdge,
+    related_edges: list[EntityEdge],
+    existing_edges: list[EntityEdge],
+    current_episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> tuple[EntityEdge, list[EntityEdge]]:
+    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
+        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
+        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
+        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
+    )
+
+    now = datetime.now(timezone.utc)
+
+    resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
+    resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
+    if invalid_at is not None and resolved_edge.expired_at is None:
+        resolved_edge.expired_at = now
+
+    # Determine if the new_edge needs to be expired
+    if resolved_edge.expired_at is None:
+        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
+        for candidate in invalidation_candidates:
+            if (
+                candidate.valid_at is not None and resolved_edge.valid_at is not None
+            ) and candidate.valid_at > resolved_edge.valid_at:
+                # Expire new edge since we have information about more recent events
+                resolved_edge.invalid_at = candidate.valid_at
+                resolved_edge.expired_at = now
+                break
+
+    # Determine which contradictory edges need to be expired
+    invalidated_edges = resolve_edge_contradictions(resolved_edge, invalidation_candidates)
+
+    return resolved_edge, invalidated_edges
 
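Pulling resolve_edge_contradictions out of resolve_extracted_edge is what lets add_triplet reuse the expiry logic without running the full episode pipeline. A sketch of the expected behaviour, using the EntityEdge fields shown in the tests; the full guard condition inside the loop is not visible in this diff, so treat this as an assumption rather than a spec:

from datetime import datetime, timedelta, timezone

from graphiti_core.edges import EntityEdge
from graphiti_core.utils.maintenance.edge_operations import resolve_edge_contradictions

now = datetime.now(timezone.utc)

new_edge = EntityEdge(
    source_node_uuid='node_1', target_node_uuid='node_2', group_id='group_1',
    name='DISLIKES', fact='Alice dislikes Bob',
    episodes=[], created_at=now, valid_at=now, invalid_at=None,
)
older_edge = EntityEdge(
    source_node_uuid='node_1', target_node_uuid='node_2', group_id='group_1',
    name='LIKES', fact='Alice likes Bob',
    episodes=[], created_at=now - timedelta(days=1),
    valid_at=now - timedelta(days=1), invalid_at=None,
)

# older_edge became valid before new_edge, so it should come back expired,
# with invalid_at set to the newer edge's valid_at.
invalidated = resolve_edge_contradictions(new_edge, [older_edge])
assert invalidated and invalidated[0].invalid_at == new_edge.valid_at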
@@ -16,7 +16,7 @@ limitations under the License.
 
 import asyncio
 import logging
-from datetime import datetime
+from datetime import datetime, timezone
 from time import time
 from typing import Any
 
@@ -113,7 +113,7 @@ async def extract_nodes(
            group_id=episode.group_id,
            labels=node_data['labels'],
            summary=node_data['summary'],
-            created_at=datetime.now(),
+            created_at=datetime.now(timezone.utc),
        )
        new_nodes.append(new_node)
        logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.3.21"
+version = "0.4.0"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",
@@ -1,4 +1,4 @@
-from datetime import datetime
+from datetime import datetime, timezone
 
 from fastapi import APIRouter, status
 
@@ -36,7 +36,7 @@ async def get_entity_edge(uuid: str, graphiti: ZepGraphitiDep):
 @router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
 async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
     episodes = await graphiti.retrieve_episodes(
-        group_ids=[group_id], last_n=last_n, reference_time=datetime.now()
+        group_ids=[group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
     )
     return episodes
 
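For reference, the episode-retrieval route touched above now anchors its reference time in UTC. A rough way to exercise it against a locally running server; the host, port, and any router prefix are assumptions, while the path and last_n query parameter come from the handler above:

import requests  # assumes the FastAPI server is running locally on port 8000

resp = requests.get('http://localhost:8000/episodes/my-group', params={'last_n': 10})
print(resp.json())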
@@ -18,7 +18,7 @@ import asyncio
 import logging
 import os
 import sys
-from datetime import datetime
+from datetime import datetime, timezone
 
 import pytest
 from dotenv import load_dotenv
@@ -66,7 +66,39 @@ def setup_logging():
 async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
-    episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None)
+    now = datetime.now(timezone.utc)
+
+    alice_node = EntityNode(
+        name='Alice',
+        labels=[],
+        created_at=now,
+        summary='Alice summary',
+        group_id='test',
+    )
+
+    bob_node = EntityNode(
+        name='Bob',
+        labels=[],
+        created_at=now,
+        summary='Bob summary',
+        group_id='test',
+    )
+
+    entity_edge = EntityEdge(
+        source_node_uuid=alice_node.uuid,
+        target_node_uuid=bob_node.uuid,
+        created_at=now,
+        name='likes',
+        fact='Alice likes Bob',
+        episodes=[],
+        expired_at=now,
+        valid_at=now,
+        group_id='test',
+    )
+
+    await graphiti.add_triplet(alice_node, entity_edge, bob_node)
+
+    episodes = await graphiti.retrieve_episodes(datetime.now(timezone.utc), group_ids=None)
     episode_uuids = [episode.uuid for episode in episodes]
 
     results = await graphiti._search(
@@ -92,7 +124,7 @@ async def test_graph_integration():
     embedder = client.embedder
     driver = client.driver
 
-    now = datetime.now()
+    now = datetime.now(timezone.utc)
     episode = EpisodicNode(
         name='test_episode',
         labels=[],
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from unittest.mock import AsyncMock, MagicMock
 
 import pytest
@@ -23,7 +23,7 @@ def mock_extracted_edge():
         group_id='group_1',
         fact='Test fact',
         episodes=['episode_1'],
-        created_at=datetime.now(),
+        created_at=datetime.now(timezone.utc),
         valid_at=None,
         invalid_at=None,
     )
@@ -39,8 +39,8 @@ def mock_related_edges():
             group_id='group_1',
             fact='Related fact',
             episodes=['episode_2'],
-            created_at=datetime.now() - timedelta(days=1),
-            valid_at=datetime.now() - timedelta(days=1),
+            created_at=datetime.now(timezone.utc) - timedelta(days=1),
+            valid_at=datetime.now(timezone.utc) - timedelta(days=1),
             invalid_at=None,
         )
     ]
@@ -56,8 +56,8 @@ def mock_existing_edges():
             group_id='group_1',
             fact='Existing fact',
             episodes=['episode_3'],
-            created_at=datetime.now() - timedelta(days=2),
-            valid_at=datetime.now() - timedelta(days=2),
+            created_at=datetime.now(timezone.utc) - timedelta(days=2),
+            valid_at=datetime.now(timezone.utc) - timedelta(days=2),
             invalid_at=None,
         )
     ]
@@ -68,7 +68,7 @@ def mock_current_episode():
     return EpisodicNode(
         uuid='episode_1',
         content='Current episode content',
-        valid_at=datetime.now(),
+        valid_at=datetime.now(timezone.utc),
         name='Current Episode',
         group_id='group_1',
         source='message',
@@ -82,7 +82,7 @@ def mock_previous_episodes():
         EpisodicNode(
             uuid='episode_2',
             content='Previous episode content',
-            valid_at=datetime.now() - timedelta(days=1),
+            valid_at=datetime.now(timezone.utc) - timedelta(days=1),
             name='Previous Episode',
             group_id='group_1',
             source='message',
@@ -144,8 +144,8 @@ async def test_resolve_extracted_edge_with_dates(
     mock_previous_episodes,
     monkeypatch: MonkeyPatch,
 ):
-    valid_at = datetime.now() - timedelta(days=1)
-    invalid_at = datetime.now() + timedelta(days=1)
+    valid_at = datetime.now(timezone.utc) - timedelta(days=1)
+    invalid_at = datetime.now(timezone.utc) + timedelta(days=1)
 
     # Mock the function calls
     dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
@@ -189,7 +189,7 @@ async def test_resolve_extracted_edge_with_invalidation(
     mock_previous_episodes,
     monkeypatch: MonkeyPatch,
 ):
-    valid_at = datetime.now() - timedelta(days=1)
+    valid_at = datetime.now(timezone.utc) - timedelta(days=1)
     mock_extracted_edge.valid_at = valid_at
 
     invalidation_candidate = EntityEdge(
@@ -199,8 +199,8 @@ async def test_resolve_extracted_edge_with_invalidation(
         group_id='group_1',
         fact='Invalidation candidate fact',
         episodes=['episode_4'],
-        created_at=datetime.now(),
-        valid_at=datetime.now() - timedelta(days=2),
+        created_at=datetime.now(timezone.utc),
+        valid_at=datetime.now(timezone.utc) - timedelta(days=2),
         invalid_at=None,
     )
 
@@ -15,11 +15,10 @@ limitations under the License.
 """
 
 import os
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 
 import pytest
 from dotenv import load_dotenv
-from pytz import UTC
 
 from graphiti_core.edges import EntityEdge
 from graphiti_core.llm_client import LLMConfig, OpenAIClient
@@ -43,7 +42,7 @@ def setup_llm_client():
 
 
 def create_test_data():
-    now = datetime.now()
+    now = datetime.now(timezone.utc)
 
     # Create edges
     existing_edge = EntityEdge(
@@ -132,7 +131,7 @@ async def test_get_edge_contradictions_multiple_existing():
 
 # Helper function to create more complex test data
 def create_complex_test_data():
-    now = datetime.now()
+    now = datetime.now(timezone.utc)
 
     # Create nodes
     node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
@@ -192,7 +191,7 @@ async def test_invalidate_edges_complex():
         name='DISLIKES',
         fact='Alice dislikes Bob',
         group_id='1',
-        created_at=datetime.now(),
+        created_at=datetime.now(timezone.utc),
     )
 
     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@@ -214,7 +213,7 @@ async def test_get_edge_contradictions_temporal_update():
         name='LEFT_JOB',
         fact='Bob no longer works at at Company XYZ',
         group_id='1',
-        created_at=datetime.now(),
+        created_at=datetime.now(timezone.utc),
     )
 
     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@@ -236,7 +235,7 @@ async def test_get_edge_contradictions_no_effect():
         name='APPLIED_TO',
         fact='Charlie applied to Company XYZ',
         group_id='1',
-        created_at=datetime.now(),
+        created_at=datetime.now(timezone.utc),
     )
 
     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@@ -257,7 +256,7 @@ async def test_invalidate_edges_partial_update():
         name='CHANGED_POSITION',
         fact='Bob changed his position at Company XYZ',
         group_id='1',
-        created_at=datetime.now(),
+        created_at=datetime.now(timezone.utc),
     )
 
     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@@ -266,7 +265,7 @@ async def test_invalidate_edges_partial_update():
 
 
 def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
-    now = datetime.now(UTC)
+    now = datetime.now(timezone.utc)
 
     previous_episodes = [
         EpisodicNode(
@@ -315,7 +314,7 @@ async def test_extract_edge_dates():
         name='LEFT_JOB',
         fact='Bob no longer works at Company XYZ',
         group_id='1',
-        created_at=datetime.now(UTC),
+        created_at=datetime.now(timezone.utc),
     )
 
     valid_at, invalid_at = await extract_edge_dates(