mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
refactor: use utc_now()
for consistent UTC datetime handling (#234)
* ensure utc timezones * fix: dep cycle --------- Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
This commit is contained in:
parent
732b2f328d
commit
445dccc021
@ -16,7 +16,7 @@ limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@ -43,10 +43,6 @@ from graphiti_core.search.search_utils import (
|
||||
get_relevant_edges,
|
||||
get_relevant_nodes,
|
||||
)
|
||||
from graphiti_core.utils import (
|
||||
build_episodic_edges,
|
||||
retrieve_episodes,
|
||||
)
|
||||
from graphiti_core.utils.bulk_utils import (
|
||||
RawEpisode,
|
||||
add_nodes_and_edges_bulk,
|
||||
@ -57,12 +53,14 @@ from graphiti_core.utils.bulk_utils import (
|
||||
resolve_edge_pointers,
|
||||
retrieve_previous_episodes_bulk,
|
||||
)
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from graphiti_core.utils.maintenance.community_operations import (
|
||||
build_communities,
|
||||
remove_communities,
|
||||
update_community,
|
||||
)
|
||||
from graphiti_core.utils.maintenance.edge_operations import (
|
||||
build_episodic_edges,
|
||||
dedupe_extracted_edge,
|
||||
extract_edges,
|
||||
resolve_edge_contradictions,
|
||||
@ -71,6 +69,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||
EPISODE_WINDOW_LEN,
|
||||
build_indices_and_constraints,
|
||||
retrieve_episodes,
|
||||
)
|
||||
from graphiti_core.utils.maintenance.node_operations import (
|
||||
extract_nodes,
|
||||
@ -313,7 +312,7 @@ class Graphiti:
|
||||
start = time()
|
||||
|
||||
entity_edges: list[EntityEdge] = []
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
|
||||
previous_episodes = await self.retrieve_episodes(
|
||||
reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
|
||||
@ -522,7 +521,7 @@ class Graphiti:
|
||||
"""
|
||||
try:
|
||||
start = time()
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
|
||||
episodes = [
|
||||
EpisodicNode(
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from time import time
|
||||
from typing import Any
|
||||
@ -34,6 +34,7 @@ from graphiti_core.models.nodes.node_db_queries import (
|
||||
ENTITY_NODE_SAVE,
|
||||
EPISODIC_NODE_SAVE,
|
||||
)
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -79,7 +80,7 @@ class Node(BaseModel, ABC):
|
||||
name: str = Field(description='name of the node')
|
||||
group_id: str = Field(description='partition of the graph')
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
created_at: datetime = Field(default_factory=lambda: utc_now())
|
||||
|
||||
@abstractmethod
|
||||
async def save(self, driver: AsyncDriver): ...
|
||||
|
@ -1,15 +0,0 @@
|
||||
from .maintenance import (
|
||||
build_episodic_edges,
|
||||
clear_data,
|
||||
extract_edges,
|
||||
extract_nodes,
|
||||
retrieve_episodes,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'extract_edges',
|
||||
'build_episodic_edges',
|
||||
'extract_nodes',
|
||||
'clear_data',
|
||||
'retrieve_episodes',
|
||||
]
|
@ -18,7 +18,7 @@ import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from math import ceil
|
||||
|
||||
from neo4j import AsyncDriver, AsyncManagedTransaction
|
||||
@ -37,14 +37,17 @@ from graphiti_core.models.nodes.node_db_queries import (
|
||||
)
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
|
||||
from graphiti_core.utils import retrieve_episodes
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from graphiti_core.utils.maintenance.edge_operations import (
|
||||
build_episodic_edges,
|
||||
dedupe_edge_list,
|
||||
dedupe_extracted_edges,
|
||||
extract_edges,
|
||||
)
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||||
EPISODE_WINDOW_LEN,
|
||||
retrieve_episodes,
|
||||
)
|
||||
from graphiti_core.utils.maintenance.node_operations import (
|
||||
dedupe_extracted_nodes,
|
||||
dedupe_node_list,
|
||||
@ -385,7 +388,7 @@ async def extract_edge_dates_bulk(
|
||||
edge.valid_at = valid_at
|
||||
edge.invalid_at = invalid_at
|
||||
if edge.invalid_at:
|
||||
edge.expired_at = datetime.now(timezone.utc)
|
||||
edge.expired_at = utc_now()
|
||||
|
||||
return edges
|
||||
|
||||
|
42
graphiti_core/utils/datetime_utils.py
Normal file
42
graphiti_core/utils/datetime_utils.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""Returns the current UTC datetime with timezone information."""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def ensure_utc(dt: datetime | None) -> datetime | None:
|
||||
"""
|
||||
Ensures a datetime is timezone-aware and in UTC.
|
||||
If the datetime is naive (no timezone), assumes it's in UTC.
|
||||
If the datetime has a different timezone, converts it to UTC.
|
||||
Returns None if input is None.
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
if dt.tzinfo is None:
|
||||
# If datetime is naive, assume it's UTC
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
elif dt.tzinfo != timezone.utc:
|
||||
# If datetime has a different timezone, convert to UTC
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
return dt
|
@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from neo4j import AsyncDriver
|
||||
from pydantic import BaseModel
|
||||
@ -17,6 +16,7 @@ from graphiti_core.nodes import (
|
||||
)
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from graphiti_core.utils.maintenance.edge_operations import build_community_edges
|
||||
|
||||
MAX_COMMUNITY_BUILD_CONCURRENCY = 10
|
||||
@ -180,7 +180,7 @@ async def build_community(
|
||||
|
||||
summary = summaries[0]
|
||||
name = await generate_summary_description(llm_client, summary)
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
community_node = CommunityNode(
|
||||
name=name,
|
||||
group_id=community_cluster[0].group_id,
|
||||
@ -307,7 +307,7 @@ async def update_community(
|
||||
community.name = new_name
|
||||
|
||||
if is_new:
|
||||
community_edge = (build_community_edges([entity], community, datetime.now(timezone.utc)))[0]
|
||||
community_edge = (build_community_edges([entity], community, utc_now()))[0]
|
||||
await community_edge.save(driver)
|
||||
|
||||
await community.generate_name_embedding(embedder)
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
|
||||
@ -26,6 +26,7 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, UniqueFacts
|
||||
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||
extract_edge_dates,
|
||||
get_edge_contradictions,
|
||||
@ -132,7 +133,7 @@ async def extract_edges(
|
||||
group_id=group_id,
|
||||
fact=edge_data.get('fact', ''),
|
||||
episodes=[episode.uuid],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
created_at=utc_now(),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
@ -251,9 +252,7 @@ def resolve_edge_contradictions(
|
||||
and edge.valid_at < resolved_edge.valid_at
|
||||
):
|
||||
edge.invalid_at = resolved_edge.valid_at
|
||||
edge.expired_at = (
|
||||
edge.expired_at if edge.expired_at is not None else datetime.now(timezone.utc)
|
||||
)
|
||||
edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
|
||||
invalidated_edges.append(edge)
|
||||
|
||||
return invalidated_edges
|
||||
@ -273,11 +272,12 @@ async def resolve_extracted_edge(
|
||||
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
|
||||
resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
|
||||
resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
|
||||
if invalid_at is not None and resolved_edge.expired_at is None:
|
||||
resolved_edge.valid_at = valid_at if valid_at else resolved_edge.valid_at
|
||||
resolved_edge.invalid_at = invalid_at if invalid_at else resolved_edge.invalid_at
|
||||
|
||||
if invalid_at and not resolved_edge.expired_at:
|
||||
resolved_edge.expired_at = now
|
||||
|
||||
# Determine if the new_edge needs to be expired
|
||||
@ -285,8 +285,12 @@ async def resolve_extracted_edge(
|
||||
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
|
||||
for candidate in invalidation_candidates:
|
||||
if (
|
||||
candidate.valid_at is not None and resolved_edge.valid_at is not None
|
||||
) and candidate.valid_at > resolved_edge.valid_at:
|
||||
candidate.valid_at
|
||||
and resolved_edge.valid_at
|
||||
and candidate.valid_at.tzinfo
|
||||
and resolved_edge.valid_at.tzinfo
|
||||
and candidate.valid_at > resolved_edge.valid_at
|
||||
):
|
||||
# Expire new edge since we have information about more recent events
|
||||
resolved_edge.invalid_at = candidate.valid_at
|
||||
resolved_edge.expired_at = now
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from time import time
|
||||
|
||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS
|
||||
@ -26,6 +25,7 @@ from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
||||
from graphiti_core.prompts.extract_nodes import ExtractedNodes, MissedEntities
|
||||
from graphiti_core.prompts.summarize_nodes import Summary
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -155,7 +155,7 @@ async def extract_nodes(
|
||||
group_id=episode.group_id,
|
||||
labels=['Entity'],
|
||||
summary='',
|
||||
created_at=datetime.now(timezone.utc),
|
||||
created_at=utc_now(),
|
||||
)
|
||||
new_nodes.append(new_node)
|
||||
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
|
||||
|
@ -9,7 +9,7 @@ You may obtain a copy of the License at
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
@ -24,6 +24,7 @@ from graphiti_core.nodes import EpisodicNode
|
||||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.prompts.extract_edge_dates import EdgeDates
|
||||
from graphiti_core.prompts.invalidate_edges import InvalidatedEdges
|
||||
from graphiti_core.utils.datetime_utils import ensure_utc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -52,13 +53,15 @@ async def extract_edge_dates(
|
||||
|
||||
if valid_at:
|
||||
try:
|
||||
valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
|
||||
valid_at_datetime = ensure_utc(datetime.fromisoformat(valid_at.replace('Z', '+00:00')))
|
||||
except ValueError as e:
|
||||
logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}')
|
||||
|
||||
if invalid_at:
|
||||
try:
|
||||
invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
|
||||
invalid_at_datetime = ensure_utc(
|
||||
datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f'Error parsing invalid_at date: {e}. Input: {invalid_at}')
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -21,9 +22,7 @@ class Message(BaseModel):
|
||||
role: str | None = Field(
|
||||
description='The custom role of the message to be used alongside role_type (user name, bot name, etc.)',
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
default_factory=datetime.now, description='The timestamp of the message'
|
||||
)
|
||||
timestamp: datetime = Field(default_factory=utc_now, description='The timestamp of the message')
|
||||
source_description: str = Field(
|
||||
default='', description='The description of the source of the message'
|
||||
)
|
||||
|
@ -4,7 +4,7 @@ from functools import partial
|
||||
|
||||
from fastapi import APIRouter, FastAPI, status
|
||||
from graphiti_core.nodes import EpisodeType # type: ignore
|
||||
from graphiti_core.utils import clear_data # type: ignore
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data # type: ignore
|
||||
|
||||
from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result
|
||||
from graph_service.zep_graphiti import ZepGraphitiDep
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
@ -23,6 +23,7 @@ from dotenv import load_dotenv
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||
extract_edge_dates,
|
||||
get_edge_contradictions,
|
||||
@ -42,7 +43,7 @@ def setup_llm_client():
|
||||
|
||||
|
||||
def create_test_data():
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
|
||||
# Create edges
|
||||
existing_edge = EntityEdge(
|
||||
@ -131,7 +132,7 @@ async def test_get_edge_contradictions_multiple_existing():
|
||||
|
||||
# Helper function to create more complex test data
|
||||
def create_complex_test_data():
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
|
||||
# Create nodes
|
||||
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
|
||||
@ -191,7 +192,7 @@ async def test_invalidate_edges_complex():
|
||||
name='DISLIKES',
|
||||
fact='Alice dislikes Bob',
|
||||
group_id='1',
|
||||
created_at=datetime.now(timezone.utc),
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||
@ -213,7 +214,7 @@ async def test_get_edge_contradictions_temporal_update():
|
||||
name='LEFT_JOB',
|
||||
fact='Bob no longer works at at Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(timezone.utc),
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||
@ -235,7 +236,7 @@ async def test_get_edge_contradictions_no_effect():
|
||||
name='APPLIED_TO',
|
||||
fact='Charlie applied to Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(timezone.utc),
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||
@ -256,7 +257,7 @@ async def test_invalidate_edges_partial_update():
|
||||
name='CHANGED_POSITION',
|
||||
fact='Bob changed his position at Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(timezone.utc),
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
|
||||
@ -265,7 +266,7 @@ async def test_invalidate_edges_partial_update():
|
||||
|
||||
|
||||
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
|
||||
now = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
|
||||
previous_episodes = [
|
||||
EpisodicNode(
|
||||
@ -314,7 +315,7 @@ async def test_extract_edge_dates():
|
||||
name='LEFT_JOB',
|
||||
fact='Bob no longer works at Company XYZ',
|
||||
group_id='1',
|
||||
created_at=datetime.now(timezone.utc),
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
valid_at, invalid_at = await extract_edge_dates(
|
||||
|
Loading…
x
Reference in New Issue
Block a user