refactor: use utc_now() for consistent UTC datetime handling (#234)

* ensure utc timezones

* fix: dep cycle

---------

Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
This commit is contained in:
Daniel Chalef 2024-12-09 10:36:04 -08:00 committed by GitHub
parent 732b2f328d
commit 445dccc021
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 97 additions and 60 deletions

View File

@ -16,7 +16,7 @@ limitations under the License.
import asyncio
import logging
from datetime import datetime, timezone
from datetime import datetime
from time import time
from dotenv import load_dotenv
@ -43,10 +43,6 @@ from graphiti_core.search.search_utils import (
get_relevant_edges,
get_relevant_nodes,
)
from graphiti_core.utils import (
build_episodic_edges,
retrieve_episodes,
)
from graphiti_core.utils.bulk_utils import (
RawEpisode,
add_nodes_and_edges_bulk,
@ -57,12 +53,14 @@ from graphiti_core.utils.bulk_utils import (
resolve_edge_pointers,
retrieve_previous_episodes_bulk,
)
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.community_operations import (
build_communities,
remove_communities,
update_community,
)
from graphiti_core.utils.maintenance.edge_operations import (
build_episodic_edges,
dedupe_extracted_edge,
extract_edges,
resolve_edge_contradictions,
@ -71,6 +69,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
@ -313,7 +312,7 @@ class Graphiti:
start = time()
entity_edges: list[EntityEdge] = []
now = datetime.now(timezone.utc)
now = utc_now()
previous_episodes = await self.retrieve_episodes(
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
@ -522,7 +521,7 @@ class Graphiti:
"""
try:
start = time()
now = datetime.now(timezone.utc)
now = utc_now()
episodes = [
EpisodicNode(

View File

@ -16,7 +16,7 @@ limitations under the License.
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from datetime import datetime
from enum import Enum
from time import time
from typing import Any
@ -34,6 +34,7 @@ from graphiti_core.models.nodes.node_db_queries import (
ENTITY_NODE_SAVE,
EPISODIC_NODE_SAVE,
)
from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)
@ -79,7 +80,7 @@ class Node(BaseModel, ABC):
name: str = Field(description='name of the node')
group_id: str = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
created_at: datetime = Field(default_factory=lambda: utc_now())
@abstractmethod
async def save(self, driver: AsyncDriver): ...

View File

@ -1,15 +0,0 @@
from .maintenance import (
build_episodic_edges,
clear_data,
extract_edges,
extract_nodes,
retrieve_episodes,
)
__all__ = [
'extract_edges',
'build_episodic_edges',
'extract_nodes',
'clear_data',
'retrieve_episodes',
]

View File

@ -18,7 +18,7 @@ import asyncio
import logging
import typing
from collections import defaultdict
from datetime import datetime, timezone
from datetime import datetime
from math import ceil
from neo4j import AsyncDriver, AsyncManagedTransaction
@ -37,14 +37,17 @@ from graphiti_core.models.nodes.node_db_queries import (
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils import retrieve_episodes
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.edge_operations import (
build_episodic_edges,
dedupe_edge_list,
dedupe_extracted_edges,
extract_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
dedupe_extracted_nodes,
dedupe_node_list,
@ -385,7 +388,7 @@ async def extract_edge_dates_bulk(
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = datetime.now(timezone.utc)
edge.expired_at = utc_now()
return edges

View File

@ -0,0 +1,42 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime, timezone
def utc_now() -> datetime:
    """Return the current moment as a timezone-aware datetime pinned to UTC."""
    return datetime.now(tz=timezone.utc)
def ensure_utc(dt: datetime | None) -> datetime | None:
"""
Ensures a datetime is timezone-aware and in UTC.
If the datetime is naive (no timezone), assumes it's in UTC.
If the datetime has a different timezone, converts it to UTC.
Returns None if input is None.
"""
if dt is None:
return None
if dt.tzinfo is None:
# If datetime is naive, assume it's UTC
return dt.replace(tzinfo=timezone.utc)
elif dt.tzinfo != timezone.utc:
# If datetime has a different timezone, convert to UTC
return dt.astimezone(timezone.utc)
return dt

View File

@ -1,7 +1,6 @@
import asyncio
import logging
from collections import defaultdict
from datetime import datetime, timezone
from neo4j import AsyncDriver
from pydantic import BaseModel
@ -17,6 +16,7 @@ from graphiti_core.nodes import (
)
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
MAX_COMMUNITY_BUILD_CONCURRENCY = 10
@ -180,7 +180,7 @@ async def build_community(
summary = summaries[0]
name = await generate_summary_description(llm_client, summary)
now = datetime.now(timezone.utc)
now = utc_now()
community_node = CommunityNode(
name=name,
group_id=community_cluster[0].group_id,
@ -307,7 +307,7 @@ async def update_community(
community.name = new_name
if is_new:
community_edge = (build_community_edges([entity], community, datetime.now(timezone.utc)))[0]
community_edge = (build_community_edges([entity], community, utc_now()))[0]
await community_edge.save(driver)
await community.generate_name_embedding(embedder)

View File

@ -16,7 +16,7 @@ limitations under the License.
import asyncio
import logging
from datetime import datetime, timezone
from datetime import datetime
from time import time
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
@ -26,6 +26,7 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
get_edge_contradictions,
@ -132,7 +133,7 @@ async def extract_edges(
group_id=group_id,
fact=edge_data.get('fact', ''),
episodes=[episode.uuid],
created_at=datetime.now(timezone.utc),
created_at=utc_now(),
valid_at=None,
invalid_at=None,
)
@ -251,9 +252,7 @@ def resolve_edge_contradictions(
and edge.valid_at < resolved_edge.valid_at
):
edge.invalid_at = resolved_edge.valid_at
edge.expired_at = (
edge.expired_at if edge.expired_at is not None else datetime.now(timezone.utc)
)
edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
invalidated_edges.append(edge)
return invalidated_edges
@ -273,11 +272,12 @@ async def resolve_extracted_edge(
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
)
now = datetime.now(timezone.utc)
now = utc_now()
resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
if invalid_at is not None and resolved_edge.expired_at is None:
resolved_edge.valid_at = valid_at if valid_at else resolved_edge.valid_at
resolved_edge.invalid_at = invalid_at if invalid_at else resolved_edge.invalid_at
if invalid_at and not resolved_edge.expired_at:
resolved_edge.expired_at = now
# Determine if the new_edge needs to be expired
@ -285,8 +285,12 @@ async def resolve_extracted_edge(
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
for candidate in invalidation_candidates:
if (
candidate.valid_at is not None and resolved_edge.valid_at is not None
) and candidate.valid_at > resolved_edge.valid_at:
candidate.valid_at
and resolved_edge.valid_at
and candidate.valid_at.tzinfo
and resolved_edge.valid_at.tzinfo
and candidate.valid_at > resolved_edge.valid_at
):
# Expire new edge since we have information about more recent events
resolved_edge.invalid_at = candidate.valid_at
resolved_edge.expired_at = now

View File

@ -16,7 +16,6 @@ limitations under the License.
import asyncio
import logging
from datetime import datetime, timezone
from time import time
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
@ -26,6 +25,7 @@ from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
from graphiti_core.prompts.summarize_nodes import Summary
from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)
@ -155,7 +155,7 @@ async def extract_nodes(
group_id=episode.group_id,
labels=['Entity'],
summary='',
created_at=datetime.now(timezone.utc),
created_at=utc_now(),
)
new_nodes.append(new_node)
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

View File

@ -9,7 +9,7 @@ You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
@ -24,6 +24,7 @@ from graphiti_core.nodes import EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.extract_edge_dates import EdgeDates
from graphiti_core.prompts.invalidate_edges import InvalidatedEdges
from graphiti_core.utils.datetime_utils import ensure_utc
logger = logging.getLogger(__name__)
@ -52,13 +53,15 @@ async def extract_edge_dates(
if valid_at:
try:
valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
valid_at_datetime = ensure_utc(datetime.fromisoformat(valid_at.replace('Z', '+00:00')))
except ValueError as e:
logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
if invalid_at:
try:
invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
invalid_at_datetime = ensure_utc(
datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
)
except ValueError as e:
logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')

View File

@ -1,6 +1,7 @@
from datetime import datetime
from typing import Literal
from graphiti_core.utils.datetime_utils import utc_now
from pydantic import BaseModel, Field
@ -21,9 +22,7 @@ class Message(BaseModel):
role: str | None = Field(
description='The custom role of the message to be used alongside role_type (user name, bot name, etc.)',
)
timestamp: datetime = Field(
default_factory=datetime.now, description='The timestamp of the message'
)
timestamp: datetime = Field(default_factory=utc_now, description='The timestamp of the message')
source_description: str = Field(
default='', description='The description of the source of the message'
)

View File

@ -4,7 +4,7 @@ from functools import partial
from fastapi import APIRouter, FastAPI, status
from graphiti_core.nodes import EpisodeType # type: ignore
from graphiti_core.utils import clear_data # type: ignore
from graphiti_core.utils.maintenance.graph_data_operations import clear_data # type: ignore
from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep

View File

@ -15,7 +15,7 @@ limitations under the License.
"""
import os
from datetime import datetime, timedelta, timezone
from datetime import timedelta
import pytest
from dotenv import load_dotenv
@ -23,6 +23,7 @@ from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
get_edge_contradictions,
@ -42,7 +43,7 @@ def setup_llm_client():
def create_test_data():
now = datetime.now(timezone.utc)
now = utc_now()
# Create edges
existing_edge = EntityEdge(
@ -131,7 +132,7 @@ async def test_get_edge_contradictions_multiple_existing():
# Helper function to create more complex test data
def create_complex_test_data():
now = datetime.now(timezone.utc)
now = utc_now()
# Create nodes
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
@ -191,7 +192,7 @@ async def test_invalidate_edges_complex():
name='DISLIKES',
fact='Alice dislikes Bob',
group_id='1',
created_at=datetime.now(timezone.utc),
created_at=utc_now(),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -213,7 +214,7 @@ async def test_get_edge_contradictions_temporal_update():
name='LEFT_JOB',
fact='Bob no longer works at at Company XYZ',
group_id='1',
created_at=datetime.now(timezone.utc),
created_at=utc_now(),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -235,7 +236,7 @@ async def test_get_edge_contradictions_no_effect():
name='APPLIED_TO',
fact='Charlie applied to Company XYZ',
group_id='1',
created_at=datetime.now(timezone.utc),
created_at=utc_now(),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -256,7 +257,7 @@ async def test_invalidate_edges_partial_update():
name='CHANGED_POSITION',
fact='Bob changed his position at Company XYZ',
group_id='1',
created_at=datetime.now(timezone.utc),
created_at=utc_now(),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -265,7 +266,7 @@ async def test_invalidate_edges_partial_update():
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
now = datetime.now(timezone.utc)
now = utc_now()
previous_episodes = [
EpisodicNode(
@ -314,7 +315,7 @@ async def test_extract_edge_dates():
name='LEFT_JOB',
fact='Bob no longer works at Company XYZ',
group_id='1',
created_at=datetime.now(timezone.utc),
created_at=utc_now(),
)
valid_at, invalid_at = await extract_edge_dates(