add_fact endpoint (#207)

* add_fact endpoint

* bump version

* add edge invalidation

* update
This commit is contained in:
Preston Rasmussen 2024-11-06 09:12:21 -05:00 committed by GitHub
parent 6536401c8c
commit 3199e893ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 196 additions and 87 deletions

View File

@ -19,7 +19,7 @@ import json
import logging
import os
import sys
from datetime import datetime
from datetime import datetime, timezone
from pathlib import Path
from dotenv import load_dotenv
@ -78,7 +78,7 @@ async def add_messages(client: Graphiti):
name=f'Message {i}',
episode_body=message,
source=EpisodeType.message,
reference_time=datetime.now(),
reference_time=datetime.now(timezone.utc),
source_description='Shoe conversation',
)
@ -105,7 +105,7 @@ async def ingest_products_data(client: Graphiti):
content=str(product),
source_description='Allbirds products',
source=EpisodeType.json,
reference_time=datetime.now(),
reference_time=datetime.now(timezone.utc),
)
for i, product in enumerate(products)
]

View File

@ -15,7 +15,7 @@ limitations under the License.
"""
import json
from datetime import datetime
from datetime import datetime, timezone
from pydantic import BaseModel
@ -45,7 +45,7 @@ def parse_msc_messages() -> list[list[ParsedMscMessage]]:
ParsedMscMessage(
speaker_name=speakers[speaker_idx],
content=content,
actual_timestamp=datetime.now(),
actual_timestamp=datetime.now(timezone.utc),
group_id=str(i),
)
)
@ -60,7 +60,7 @@ def parse_msc_messages() -> list[list[ParsedMscMessage]]:
ParsedMscMessage(
speaker_name=speakers[speaker_idx],
content=content,
actual_timestamp=datetime.now(),
actual_timestamp=datetime.now(timezone.utc),
group_id=str(i),
)
)

View File

@ -1,6 +1,6 @@
import os
import re
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import List
from pydantic import BaseModel
@ -61,7 +61,7 @@ def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[Par
break
# Calculate the start time
now = datetime.now()
now = datetime.now(timezone.utc)
podcast_start_time = now - last_timestamp
for message in messages:

View File

@ -18,7 +18,7 @@ import asyncio
import logging
import os
import sys
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from dotenv import load_dotenv
@ -63,7 +63,7 @@ async def main():
messages = get_wizard_of_oz_messages()
print(messages)
print(len(messages))
now = datetime.now()
now = datetime.now(timezone.utc)
# episodes: list[BulkEpisode] = [
# BulkEpisode(
# name=f'Chapter {i + 1}',

View File

@ -16,7 +16,7 @@ limitations under the License.
import asyncio
import logging
from datetime import datetime
from datetime import datetime, timezone
from time import time
from dotenv import load_dotenv
@ -65,7 +65,9 @@ from graphiti_core.utils.maintenance.community_operations import (
update_community,
)
from graphiti_core.utils.maintenance.edge_operations import (
dedupe_extracted_edge,
extract_edges,
resolve_edge_contradictions,
resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
@ -76,6 +78,7 @@ from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions
logger = logging.getLogger(__name__)
@ -312,7 +315,7 @@ class Graphiti:
start = time()
entity_edges: list[EntityEdge] = []
now = datetime.now()
now = datetime.now(timezone.utc)
previous_episodes = await self.retrieve_episodes(
reference_time, last_n=3, group_ids=[group_id]
@ -448,7 +451,6 @@ class Graphiti:
episode.entity_edges = [edge.uuid for edge in entity_edges]
# Future optimization would be using batch operations to save nodes and edges
if not self.store_raw_episode_content:
episode.content = ''
@ -511,7 +513,7 @@ class Graphiti:
"""
try:
start = time()
now = datetime.now()
now = datetime.now(timezone.utc)
episodes = [
EpisodicNode(
@ -760,3 +762,36 @@ class Graphiti:
communities = await get_communities_by_nodes(self.driver, nodes)
return SearchResults(edges=edges, nodes=nodes, communities=communities)
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
    """Add a single (source node, edge, target node) fact to the graph.

    Missing embeddings are generated first. Both nodes are then resolved
    against the relevant nodes already in the graph, the edge is re-pointed
    at the resolved nodes, deduplicated against related existing edges, and
    checked for contradictions. The resolved nodes, the resolved edge and
    any edges invalidated by it are saved in one bulk write.

    Args:
        source_node: Node the edge starts from. Its name embedding is
            generated in place when it is ``None``.
        edge: The fact connecting the two nodes. Its fact embedding is
            generated in place when it is ``None``, and its
            ``source_node_uuid`` / ``target_node_uuid`` are updated in place
            to the UUIDs of the resolved nodes.
        target_node: Node the edge points to. Its name embedding is
            generated in place when it is ``None``.
    """
    # Embeddings are only computed when the caller has not supplied them.
    if source_node.name_embedding is None:
        await source_node.generate_name_embedding(self.embedder)
    if target_node.name_embedding is None:
        await target_node.generate_name_embedding(self.embedder)
    if edge.fact_embedding is None:
        await edge.generate_embedding(self.embedder)

    resolved_nodes, _ = await resolve_extracted_nodes(
        self.llm_client,
        [source_node, target_node],
        [
            await get_relevant_nodes([source_node], self.driver),
            await get_relevant_nodes([target_node], self.driver),
        ],
    )

    # Resolution may replace an input node with a different node (different
    # UUID). Only the resolved nodes are saved below, so the edge must point
    # at them; otherwise it would reference UUIDs that are never written.
    # NOTE(review): relies on resolved_nodes keeping the [source, target]
    # input order, which the related-edge lookup below already assumes.
    edge.source_node_uuid = resolved_nodes[0].uuid
    edge.target_node_uuid = resolved_nodes[1].uuid

    related_edges = await get_relevant_edges(
        self.driver,
        [edge],
        source_node_uuid=resolved_nodes[0].uuid,
        target_node_uuid=resolved_nodes[1].uuid,
    )

    resolved_edge = await dedupe_extracted_edge(self.llm_client, edge, related_edges)

    contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges)
    invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges)

    # No episodic nodes or episodic edges are written for a directly added fact.
    await add_nodes_and_edges_bulk(
        self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges
    )

View File

@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
EPISODIC_EDGE_SAVE = """
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})

View File

@ -1,3 +1,19 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
EPISODIC_NODE_SAVE = """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,

View File

@ -16,7 +16,7 @@ limitations under the License.
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from datetime import datetime, timezone
from enum import Enum
from time import time
from typing import Any
@ -78,7 +78,7 @@ class Node(BaseModel, ABC):
name: str = Field(description='name of the node')
group_id: str = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now())
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@abstractmethod
async def save(self, driver: AsyncDriver): ...

View File

@ -18,7 +18,7 @@ import asyncio
import logging
import typing
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timezone
from math import ceil
from neo4j import AsyncDriver, AsyncManagedTransaction
@ -385,7 +385,7 @@ async def extract_edge_dates_bulk(
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = datetime.now()
edge.expired_at = datetime.now(timezone.utc)
return edges

View File

@ -1,7 +1,7 @@
import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timezone
from neo4j import AsyncDriver
from pydantic import BaseModel
@ -178,7 +178,7 @@ async def build_community(
summary = summaries[0]
name = await generate_summary_description(llm_client, summary)
now = datetime.now()
now = datetime.now(timezone.utc)
community_node = CommunityNode(
name=name,
group_id=community_cluster[0].group_id,

View File

@ -16,7 +16,7 @@ limitations under the License.
import asyncio
import logging
from datetime import datetime
from datetime import datetime, timezone
from time import time
from typing import List
@ -110,7 +110,7 @@ async def extract_edges(
group_id=group_id,
fact=edge_data['fact'],
episodes=[episode.uuid],
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
@ -205,39 +205,9 @@ async def resolve_extracted_edges(
return resolved_edges, invalidated_edges
async def resolve_extracted_edge(
llm_client: LLMClient,
extracted_edge: EntityEdge,
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
current_episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]:
resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
get_edge_contradictions(llm_client, extracted_edge, existing_edges),
)
now = datetime.now()
resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
if invalid_at is not None and resolved_edge.expired_at is None:
resolved_edge.expired_at = now
# Determine if the new_edge needs to be expired
if resolved_edge.expired_at is None:
invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
for candidate in invalidation_candidates:
if (
candidate.valid_at is not None and resolved_edge.valid_at is not None
) and candidate.valid_at > resolved_edge.valid_at:
# Expire new edge since we have information about more recent events
resolved_edge.invalid_at = candidate.valid_at
resolved_edge.expired_at = now
break
def resolve_edge_contradictions(
resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
) -> list[EntityEdge]:
# Determine which contradictory edges need to be expired
invalidated_edges: list[EntityEdge] = []
for edge in invalidation_candidates:
@ -259,9 +229,50 @@ async def resolve_extracted_edge(
and edge.valid_at < resolved_edge.valid_at
):
edge.invalid_at = resolved_edge.valid_at
edge.expired_at = edge.expired_at if edge.expired_at is not None else now
edge.expired_at = (
edge.expired_at if edge.expired_at is not None else datetime.now(timezone.utc)
)
invalidated_edges.append(edge)
return invalidated_edges
async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]:
    """Resolve one extracted edge and work out which edges it invalidates.

    Deduplication, date extraction and contradiction detection are awaited
    concurrently. The extracted dates are applied to the resolved edge, the
    resolved edge is expired if it already carries an end date or if a
    contradicting edge became valid later, and the contradicting edges are
    passed on to ``resolve_edge_contradictions``.

    Returns:
        A tuple of the resolved edge and the list of invalidated edges.
    """
    resolved_edge, extracted_dates, invalidation_candidates = await asyncio.gather(
        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
    )
    valid_at, invalid_at = extracted_dates

    # A single timestamp is shared by every expiry decision made in this call.
    expiry_time = datetime.now(timezone.utc)

    # Extracted dates win; when one is missing the edge keeps its current value.
    new_valid_at = resolved_edge.valid_at if valid_at is None else valid_at
    new_invalid_at = resolved_edge.invalid_at if invalid_at is None else invalid_at
    resolved_edge.valid_at = new_valid_at
    resolved_edge.invalid_at = new_invalid_at

    if resolved_edge.expired_at is None and invalid_at is not None:
        resolved_edge.expired_at = expiry_time

    # Determine if the new edge needs to be expired
    if resolved_edge.expired_at is None:
        # In-place sort: dated candidates first (oldest to newest), undated last.
        invalidation_candidates.sort(
            key=lambda candidate: (candidate.valid_at is None, candidate.valid_at)
        )
        newer_candidate = next(
            (
                candidate
                for candidate in invalidation_candidates
                if candidate.valid_at is not None
                and resolved_edge.valid_at is not None
                and candidate.valid_at > resolved_edge.valid_at
            ),
            None,
        )
        if newer_candidate is not None:
            # Expire new edge since we have information about more recent events
            resolved_edge.invalid_at = newer_candidate.valid_at
            resolved_edge.expired_at = expiry_time

    # Determine which contradictory edges need to be expired
    return resolved_edge, resolve_edge_contradictions(resolved_edge, invalidation_candidates)

View File

@ -16,7 +16,7 @@ limitations under the License.
import asyncio
import logging
from datetime import datetime
from datetime import datetime, timezone
from time import time
from typing import Any
@ -113,7 +113,7 @@ async def extract_nodes(
group_id=episode.group_id,
labels=node_data['labels'],
summary=node_data['summary'],
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
)
new_nodes.append(new_node)
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.3.21"
version = "0.4.0"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",

View File

@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
from fastapi import APIRouter, status
@ -36,7 +36,7 @@ async def get_entity_edge(uuid: str, graphiti: ZepGraphitiDep):
@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
episodes = await graphiti.retrieve_episodes(
group_ids=[group_id], last_n=last_n, reference_time=datetime.now()
group_ids=[group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
)
return episodes

View File

@ -18,7 +18,7 @@ import asyncio
import logging
import os
import sys
from datetime import datetime
from datetime import datetime, timezone
import pytest
from dotenv import load_dotenv
@ -66,7 +66,39 @@ def setup_logging():
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
episodes = await graphiti.retrieve_episodes(datetime.now(), group_ids=None)
now = datetime.now(timezone.utc)
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id='test',
)
bob_node = EntityNode(
name='Bob',
labels=[],
created_at=now,
summary='Bob summary',
group_id='test',
)
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
group_id='test',
)
await graphiti.add_triplet(alice_node, entity_edge, bob_node)
episodes = await graphiti.retrieve_episodes(datetime.now(timezone.utc), group_ids=None)
episode_uuids = [episode.uuid for episode in episodes]
results = await graphiti._search(
@ -92,7 +124,7 @@ async def test_graph_integration():
embedder = client.embedder
driver = client.driver
now = datetime.now()
now = datetime.now(timezone.utc)
episode = EpisodicNode(
name='test_episode',
labels=[],

View File

@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -23,7 +23,7 @@ def mock_extracted_edge():
group_id='group_1',
fact='Test fact',
episodes=['episode_1'],
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
@ -39,8 +39,8 @@ def mock_related_edges():
group_id='group_1',
fact='Related fact',
episodes=['episode_2'],
created_at=datetime.now() - timedelta(days=1),
valid_at=datetime.now() - timedelta(days=1),
created_at=datetime.now(timezone.utc) - timedelta(days=1),
valid_at=datetime.now(timezone.utc) - timedelta(days=1),
invalid_at=None,
)
]
@ -56,8 +56,8 @@ def mock_existing_edges():
group_id='group_1',
fact='Existing fact',
episodes=['episode_3'],
created_at=datetime.now() - timedelta(days=2),
valid_at=datetime.now() - timedelta(days=2),
created_at=datetime.now(timezone.utc) - timedelta(days=2),
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
invalid_at=None,
)
]
@ -68,7 +68,7 @@ def mock_current_episode():
return EpisodicNode(
uuid='episode_1',
content='Current episode content',
valid_at=datetime.now(),
valid_at=datetime.now(timezone.utc),
name='Current Episode',
group_id='group_1',
source='message',
@ -82,7 +82,7 @@ def mock_previous_episodes():
EpisodicNode(
uuid='episode_2',
content='Previous episode content',
valid_at=datetime.now() - timedelta(days=1),
valid_at=datetime.now(timezone.utc) - timedelta(days=1),
name='Previous Episode',
group_id='group_1',
source='message',
@ -144,8 +144,8 @@ async def test_resolve_extracted_edge_with_dates(
mock_previous_episodes,
monkeypatch: MonkeyPatch,
):
valid_at = datetime.now() - timedelta(days=1)
invalid_at = datetime.now() + timedelta(days=1)
valid_at = datetime.now(timezone.utc) - timedelta(days=1)
invalid_at = datetime.now(timezone.utc) + timedelta(days=1)
# Mock the function calls
dedupe_mock = AsyncMock(return_value=mock_extracted_edge)
@ -189,7 +189,7 @@ async def test_resolve_extracted_edge_with_invalidation(
mock_previous_episodes,
monkeypatch: MonkeyPatch,
):
valid_at = datetime.now() - timedelta(days=1)
valid_at = datetime.now(timezone.utc) - timedelta(days=1)
mock_extracted_edge.valid_at = valid_at
invalidation_candidate = EntityEdge(
@ -199,8 +199,8 @@ async def test_resolve_extracted_edge_with_invalidation(
group_id='group_1',
fact='Invalidation candidate fact',
episodes=['episode_4'],
created_at=datetime.now(),
valid_at=datetime.now() - timedelta(days=2),
created_at=datetime.now(timezone.utc),
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
invalid_at=None,
)

View File

@ -15,11 +15,10 @@ limitations under the License.
"""
import os
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
import pytest
from dotenv import load_dotenv
from pytz import UTC
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMConfig, OpenAIClient
@ -43,7 +42,7 @@ def setup_llm_client():
def create_test_data():
now = datetime.now()
now = datetime.now(timezone.utc)
# Create edges
existing_edge = EntityEdge(
@ -132,7 +131,7 @@ async def test_get_edge_contradictions_multiple_existing():
# Helper function to create more complex test data
def create_complex_test_data():
now = datetime.now()
now = datetime.now(timezone.utc)
# Create nodes
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
@ -192,7 +191,7 @@ async def test_invalidate_edges_complex():
name='DISLIKES',
fact='Alice dislikes Bob',
group_id='1',
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -214,7 +213,7 @@ async def test_get_edge_contradictions_temporal_update():
name='LEFT_JOB',
fact='Bob no longer works at at Company XYZ',
group_id='1',
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -236,7 +235,7 @@ async def test_get_edge_contradictions_no_effect():
name='APPLIED_TO',
fact='Charlie applied to Company XYZ',
group_id='1',
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -257,7 +256,7 @@ async def test_invalidate_edges_partial_update():
name='CHANGED_POSITION',
fact='Bob changed his position at Company XYZ',
group_id='1',
created_at=datetime.now(),
created_at=datetime.now(timezone.utc),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
@ -266,7 +265,7 @@ async def test_invalidate_edges_partial_update():
def create_data_for_temporal_extraction() -> tuple[EpisodicNode, list[EpisodicNode]]:
now = datetime.now(UTC)
now = datetime.now(timezone.utc)
previous_episodes = [
EpisodicNode(
@ -315,7 +314,7 @@ async def test_extract_edge_dates():
name='LEFT_JOB',
fact='Bob no longer works at Company XYZ',
group_id='1',
created_at=datetime.now(UTC),
created_at=datetime.now(timezone.utc),
)
valid_at, invalid_at = await extract_edge_dates(