Controlled example (#37)

* chore: Add romeo runner

* fix: Linter

* dedupe fixes

* wip

* wip dump

* allbirds

* chore: Update romeo parser

* chore: Anthropic model fix

* allbirds runner

* format

* wip

* mypy updates

* update

* remove r

* update tests

* format

* wip

* wip

* wip

* chore: Strategically update the message

* chore: Add romeo runner

* fix: Linter

* wip

* wip dump

* chore: Update romeo parser

* chore: Anthropic model fix

* wip

* allbirds

* allbirds runner

* format

* wip

* wip

* mypy updates

* update

* remove r

* update tests

* format

* wip

* chore: Strategically update the message

* rebase and fix import issues

* Update package imports for graphiti_core in examples and utils

* nits

* chore: Update OpenAI GPT-4o model to gpt-4o-2024-08-06

* implement groq

* improvments & linting

* cleanup and nits

* Refactor package imports for graphiti_core in examples and utils

* Refactor package imports for graphiti_core in examples and utils

* chore: Nuke unused examples

* chore: Nuke unused examples

* chore: Only run type check on graphiti_core

* fix unit tests

* reformat

* unit test

* fix: Unit tests

* test: Add coverage for extract_date_strings_from_edge

* lint

* remove commented code

---------

Co-authored-by: prestonrasmussen <prasmuss15@gmail.com>
Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
This commit is contained in:
Pavlo Paliychuk 2024-08-26 10:30:22 -04:00 committed by GitHub
parent c5e52153c4
commit 0ed7739bc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 8120 additions and 275 deletions

View File

@ -37,7 +37,7 @@ jobs:
shell: bash
run: |
set -o pipefail
poetry run mypy . --show-column-numbers --show-error-codes | sed -E '
poetry run mypy ./graphiti_core --show-column-numbers --show-error-codes | sed -E '
s/^(.*):([0-9]+):([0-9]+): (error|warning): (.+) \[(.+)\]/::error file=\1,line=\2,endLine=\2,col=\3,title=\6::\5/;
s/^(.*):([0-9]+):([0-9]+): note: (.+)/::notice file=\1,line=\2,endLine=\2,col=\3,title=Note::\4/;
'

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,117 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
shoe_conversation = [
"SalesBot: Hi, I'm Allbirds Assistant! How can I help you today?",
"John: Hi, I'm looking for a new pair of shoes.",
'SalesBot: Of course! What kinde of material are you looking for?',
"John: I'm looking for shoes made out of wool",
"""SalesBot: We have just what you are looking for, how do you like our Men's SuperLight Wool Runners
- Dark Grey (Medium Grey Sole)? They use the SuperLight Foam technology.""",
"""John: Oh, actually I bought those 2 months ago, but unfortunately found out that I was allergic to wool.
I think I will pass on those, maybe there is something with a retro look that you could suggest?""",
"""SalesBot: Im sorry to hear that! Would you be interested in Men's Couriers -
(Blizzard Sole) model? We have them in Natural Black and Basin Blue colors""",
'John: Oh that is perfect, I LOVE the Natural Black color!. I will take those.',
]
async def add_messages(client: Graphiti):
for i, message in enumerate(shoe_conversation):
await client.add_episode(
name=f'Message {i}',
episode_body=message,
source=EpisodeType.message,
reference_time=datetime.now(),
source_description='Shoe conversation',
)
async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
await client.build_indices_and_constraints()
await ingest_products_data(client)
await add_messages(client)
async def ingest_products_data(client: Graphiti):
script_dir = Path(__file__).parent
json_file_path = script_dir / 'allbirds_products.json'
with open(json_file_path) as file:
products = json.load(file)['products']
episodes: list[RawEpisode] = [
RawEpisode(
name=f'Product {i}',
content=str(product),
source_description='Allbirds products',
source=EpisodeType.json,
reference_time=datetime.now(),
)
for i, product in enumerate(products)
]
await client.add_episode_bulk(episodes)
asyncio.run(main())

View File

@ -23,7 +23,8 @@ from dotenv import load_dotenv
from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti
from graphiti_core.utils.bulk_utils import BulkEpisode
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
@ -70,12 +71,12 @@ async def main(use_bulk: bool = True):
source_description='Podcast Transcript',
)
episodes: list[BulkEpisode] = [
BulkEpisode(
episodes: list[RawEpisode] = [
RawEpisode(
name=f'Message {i}',
content=f'{message.speaker_name} ({message.role}): {message.content}',
source=EpisodeType.message,
source_description='Podcast Transcript',
episode_type='string',
reference_time=message.actual_timestamp,
)
for i, message in enumerate(messages[3:14])

View File

@ -0,0 +1,36 @@
import os
import re
def parse_wizard_of_oz(file_path):
with open(file_path, encoding='utf-8') as file:
content = file.read()
# Split the content into chapters
chapters = re.split(r'\n\n+Chapter [IVX]+\n', content)[
1:
] # Skip the first split which is before Chapter I
episodes = []
for i, chapter in enumerate(chapters, start=1):
# Extract chapter title
title_match = re.match(r'(.*?)\n\n', chapter)
title = title_match.group(1) if title_match else f'Chapter {i}'
# Remove the title from the chapter content
chapter_content = chapter[len(title) :].strip() if title_match else chapter.strip()
# Create episode dictionary
episode = {'episode_number': i, 'title': title, 'content': chapter_content}
episodes.append(episode)
return episodes
def get_wizard_of_oz_messages():
file_path = 'woo.txt'
script_dir = os.path.dirname(__file__)
relative_path = os.path.join(script_dir, file_path)
# Use the function
parsed_episodes = parse_wizard_of_oz(relative_path)
return parsed_episodes

View File

@ -0,0 +1,93 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import logging
import os
import sys
from datetime import datetime, timedelta
from dotenv import load_dotenv
from examples.wizard_of_oz.parser import get_wizard_of_oz_messages
from graphiti_core import Graphiti
from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
async def main():
setup_logging()
llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY')))
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password, llm_client)
messages = get_wizard_of_oz_messages()
print(messages)
print(len(messages))
now = datetime.now()
# episodes: list[BulkEpisode] = [
# BulkEpisode(
# name=f'Chapter {i + 1}',
# content=chapter['content'],
# source_description='Wizard of Oz Transcript',
# episode_type='string',
# reference_time=now + timedelta(seconds=i * 10),
# )
# for i, chapter in enumerate(messages[0:50])
# ]
# await clear_data(client.driver)
# await client.build_indices_and_constraints()
# await client.add_episode_bulk(episodes)
await clear_data(client.driver)
await client.build_indices_and_constraints()
for i, chapter in enumerate(messages):
await client.add_episode(
name=f'Chapter {i + 1}',
episode_body=chapter['content'],
source_description='Wizard of Oz Transcript',
reference_time=now + timedelta(seconds=i * 10),
)
asyncio.run(main())

File diff suppressed because it is too large Load Diff

View File

@ -26,7 +26,7 @@ from neo4j import AsyncGraphDatabase
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient, LLMConfig, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, hybrid_search
from graphiti_core.search.search_utils import (
get_relevant_edges,
@ -37,7 +37,7 @@ from graphiti_core.utils import (
retrieve_episodes,
)
from graphiti_core.utils.bulk_utils import (
BulkEpisode,
RawEpisode,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_nodes_and_edges_bulk,
@ -74,8 +74,7 @@ class Graphiti:
self.llm_client = OpenAIClient(
LLMConfig(
api_key=os.getenv('OPENAI_API_KEY', default=''),
model='gpt-4o-mini',
base_url='https://api.openai.com/v1',
model='gpt-4o-2024-08-06',
)
)
@ -99,6 +98,7 @@ class Graphiti:
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
@ -112,11 +112,11 @@ class Graphiti:
embedder = self.llm_client.get_embedder()
now = datetime.now()
previous_episodes = await self.retrieve_episodes(reference_time)
previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
episode = EpisodicNode(
name=name,
labels=[],
source='messages',
source=source,
content=episode_body,
source_description=source_description,
created_at=now,
@ -149,12 +149,6 @@ class Graphiti:
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
# deduped_edges = await dedupe_extracted_edges_v2(
# self.llm_client,
# extract_node_and_edge_triplets(extracted_edges, nodes),
# extract_node_and_edge_triplets(existing_edges, nodes),
# )
deduped_edges = await dedupe_extracted_edges(
self.llm_client,
extracted_edges,
@ -166,6 +160,30 @@ class Graphiti:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
for edge in deduped_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode.valid_at,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = datetime.now()
for edge in existing_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode.valid_at,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = datetime.now()
(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
@ -190,19 +208,10 @@ class Graphiti:
deduped_edge.expired_at = edge.expired_at
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
edges_to_save = existing_edges + deduped_edges
for edge_to_extract_dates_from in edges_to_save:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge_to_extract_dates_from,
episode.valid_at,
episode,
previous_episodes,
)
edge_to_extract_dates_from.valid_at = valid_at
edge_to_extract_dates_from.invalid_at = invalid_at
entity_edges.extend(edges_to_save)
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
@ -210,8 +219,6 @@ class Graphiti:
logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
episodic_edges.extend(
@ -232,7 +239,7 @@ class Graphiti:
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
end = time()
logger.info(f'Completed add_episode in {(end-start) * 1000} ms')
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
@ -246,7 +253,7 @@ class Graphiti:
async def add_episode_bulk(
self,
bulk_episodes: list[BulkEpisode],
bulk_episodes: list[RawEpisode],
):
try:
start = time()
@ -257,7 +264,7 @@ class Graphiti:
EpisodicNode(
name=episode.name,
labels=[],
source='messages',
source=episode.source,
content=episode.content,
source_description=episode.source_description,
created_at=now,
@ -316,7 +323,7 @@ class Graphiti:
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
end = time()
logger.info(f'Completed add_episode_bulk in {(end-start) * 1000} ms')
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
except Exception as e:
raise e

View File

@ -1,5 +1,7 @@
from .anthropic_client import AnthropicClient
from .client import LLMClient
from .config import LLMConfig
from .groq_client import GroqClient
from .openai_client import OpenAIClient
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'AnthropicClient', 'GroqClient']

View File

@ -0,0 +1,60 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
logger = logging.getLogger(__name__)
class AnthropicClient(LLMClient):
def __init__(self, config: LLMConfig | None = None):
if config is None:
config = LLMConfig()
self.client = AsyncAnthropic(api_key=config.api_key)
self.model = config.model
def get_embedder(self) -> typing.Any:
openai_client = AsyncOpenAI()
return openai_client.embeddings
async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
system_message = messages[0]
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
{'role': 'assistant', 'content': '{'}
]
try:
result = await self.client.messages.create(
system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
+ system_message.content,
max_tokens=4096,
messages=user_messages, # type: ignore
model='claude-3-5-sonnet-20240620',
)
return json.loads('{' + result.content[0].text) # type: ignore
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise

View File

@ -23,7 +23,7 @@ from .config import LLMConfig
class LLMClient(ABC):
@abstractmethod
def __init__(self, config: LLMConfig):
def __init__(self, config: LLMConfig | None):
pass
@abstractmethod

View File

@ -28,9 +28,9 @@ class LLMConfig:
def __init__(
self,
api_key: str,
api_key: str | None = None,
model: str = 'gpt-4o-mini',
base_url: str = 'https://api.openai.com',
base_url: str = 'https://api.openai.com/v1',
):
"""
Initialize the LLMConfig with the provided parameters.

View File

@ -0,0 +1,63 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from groq import AsyncGroq
from groq.types.chat import ChatCompletionMessageParam
from openai import AsyncOpenAI
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
logger = logging.getLogger(__name__)
class GroqClient(LLMClient):
def __init__(self, config: LLMConfig | None = None):
if config is None:
config = LLMConfig()
self.client = AsyncGroq(api_key=config.api_key)
self.model = config.model
def get_embedder(self) -> typing.Any:
openai_client = AsyncOpenAI()
return openai_client.embeddings
async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':
openai_messages.append({'role': 'system', 'content': m.content})
try:
response = await self.client.chat.completions.create(
model='llama-3.1-70b-versatile',
messages=openai_messages,
temperature=0.0,
max_tokens=4096,
response_format={'type': 'json_object'},
)
result = response.choices[0].message.content or ''
return json.loads(result)
except Exception as e:
print(openai_messages)
logger.error(f'Error in generating LLM response: {e}')
raise

View File

@ -29,7 +29,9 @@ logger = logging.getLogger(__name__)
class OpenAIClient(LLMClient):
def __init__(self, config: LLMConfig):
def __init__(self, config: LLMConfig | None = None):
if config is None:
config = LLMConfig()
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
self.model = config.model
@ -48,11 +50,12 @@ class OpenAIClient(LLMClient):
model=self.model,
messages=openai_messages,
temperature=0,
max_tokens=3000,
max_tokens=4096,
response_format={'type': 'json_object'},
)
result = response.choices[0].message.content or ''
return json.loads(result)
except Exception as e:
print(openai_messages)
logger.error(f'Error in generating LLM response: {e}')
raise

View File

@ -17,6 +17,7 @@ limitations under the License.
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from time import time
from uuid import uuid4
@ -29,6 +30,20 @@ from graphiti_core.llm_client.config import EMBEDDING_DIM
logger = logging.getLogger(__name__)
class EpisodeType(Enum):
message = 'message'
json = 'json'
@staticmethod
def from_str(episode_type: str):
if episode_type == 'message':
return EpisodeType.message
if episode_type == 'json':
return EpisodeType.json
logger.error(f'Episode type: {episode_type} not implemented')
raise NotImplementedError
class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: uuid4().hex)
name: str
@ -48,7 +63,7 @@ class Node(BaseModel, ABC):
class EpisodicNode(Node):
source: str = Field(description='source type')
source: EpisodeType = Field(description='source type')
source_description: str = Field(description='description of the data source')
content: str = Field(description='raw episode data')
valid_at: datetime = Field(
@ -73,7 +88,7 @@ class EpisodicNode(Node):
entity_edges=self.entity_edges,
created_at=self.created_at,
valid_at=self.valid_at,
source=self.source,
source=self.source.value,
_database='neo4j',
)
@ -96,7 +111,7 @@ class EntityNode(Node):
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.name_embedding = embedding[:EMBEDDING_DIM]
end = time()
logger.info(f'embedded {text} in {end-start} ms')
logger.info(f'embedded {text} in {end - start} ms')
return embedding

View File

@ -88,7 +88,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
New Nodes:
{json.dumps(context['extracted_nodes'], indent=2)}
Important:
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
Task:
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list

View File

@ -29,7 +29,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
Reference Timestamp: {context['reference_timestamp']}
IMPORTANT: Only extract time information if it is part of the provided fact. Otherwise ignore the time mentioned. Make sure to do your best to determine the dates if only the relative time is mentioned. (eg 10 years ago, 2 mins ago) based on the provided reference timestamp
If the relationship is not of spanning nature, but you are still able to determine the dates, set the valid_at only.
Definitions:
- valid_at: The date and time when the relationship described by the edge fact became true or was established.
- invalid_at: The date and time when the relationship described by the edge fact stopped being true or ended.

View File

@ -23,64 +23,16 @@ from .models import Message, PromptFunction, PromptVersion
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
v3: PromptVersion
extract_json: PromptVersion
class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
v3: PromptFunction
extract_json: PromptFunction
def v1(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that extracts graph nodes from provided context.',
),
Message(
role='user',
content=f"""
Given the following context, extract new semantic nodes that need to be added to the knowledge graph:
Existing Nodes:
{json.dumps(context['existing_nodes'], indent=2)}
Previous Episodes:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
New Episode:
Content: {context["episode_content"]}
Timestamp: {context['episode_timestamp']}
Extract new semantic nodes based on the content of the current episode, while considering the existing nodes and context from previous episodes.
Guidelines:
1. Only extract new nodes that don't already exist in the graph structure.
2. Focus on entities, concepts, or actors that are central to the current episode.
3. Avoid creating nodes for relationships or actions (these will be handled as edges later).
4. Provide a brief but informative summary for each node.
5. If a node seems to represent an existing concept but with updated information, don't create a new node. This will be handled by edge updates.
6. Do not create nodes for episodic content (like Message 1 or Message 2).
Respond with a JSON object in the following format:
{{
"new_nodes": [
{{
"name": "Unique identifier for the node",
"labels": ["Semantic", "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
If no new nodes need to be added, return an empty list for "new_nodes".
""",
),
]
def v2(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
@ -121,7 +73,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
]
def v3(context: dict[str, Any]) -> list[Message]:
def v2(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
user_prompt = f"""
@ -141,7 +93,7 @@ Guidelines:
Respond with a JSON object in the following format:
{{
"new_nodes": [
"extracted_nodes": [
{{
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
"labels": ["Entity", "Speaker" for speaker nodes, "OptionalAdditionalLabel"],
@ -156,4 +108,38 @@ Respond with a JSON object in the following format:
]
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3}
def extract_json(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text.
Your primary task is to identify and extract relevant entities from JSON files"""
user_prompt = f"""
Given the following source description, extract relevant entity nodes from the provided JSON:
Source Description:
{context["source_description"]}
JSON:
{context["episode_content"]}
Guidelines:
1. Always try to extract an entities that the JSON represents. This will often be something like a "name" or "user field
2. Do NOT extract any properties that contain dates
Respond with a JSON object in the following format:
{{
"extracted_nodes": [
{{
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
"labels": ["Entity", "Speaker" for speaker nodes, "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
"""
return [
Message(role='system', content=sys_prompt),
Message(role='user', content=user_prompt),
]
versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}

View File

@ -36,9 +36,10 @@ def v1(context: dict[str, Any]) -> list[Message]:
Message(
role='user',
content=f"""
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
Do not invalidate relationships merely because they weren't mentioned in new edges. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships.
Based on the provided existing edges and new edges with their timestamps, determine which relationships, if any, should be marked as expired due to contradictions or updates in the newer edges.
Use the start and end dates of the edges to determine which edges are to be marked expired.
Only mark a relationship as invalid if there is clear evidence from other edges that the relationship is no longer true.
Do not invalidate relationships merely because they weren't mentioned in the episodes. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships.
Previous Episodes:
{context['previous_episodes']}
@ -52,7 +53,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
New Edges:
{context['new_edges']}
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), TIMESTAMP)"
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), START_DATE (END_DATE, optional))"
For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{

View File

@ -1,5 +1,6 @@
import asyncio
import logging
import re
import typing
from collections import defaultdict
from datetime import datetime
@ -186,7 +187,7 @@ async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = query + '~'
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
@ -221,7 +222,7 @@ async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = query + '~'
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
records, _, _ = await driver.execute_query(
"""

View File

@ -24,7 +24,7 @@ from pydantic import BaseModel
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils import retrieve_episodes
from graphiti_core.utils.maintenance.edge_operations import (
@ -43,11 +43,11 @@ from graphiti_core.utils.maintenance.node_operations import (
CHUNK_SIZE = 15
class BulkEpisode(BaseModel):
class RawEpisode(BaseModel):
name: str
content: str
source_description: str
episode_type: str
source: EpisodeType
reference_time: datetime

View File

@ -70,6 +70,7 @@ async def extract_edges(
}
llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context))
print(llm_response)
edges_data = llm_response.get('edges', [])
end = time()

View File

@ -21,7 +21,7 @@ from datetime import datetime, timezone
from neo4j import AsyncDriver
from typing_extensions import LiteralString
from graphiti_core.nodes import EpisodicNode
from graphiti_core.nodes import EpisodeType, EpisodicNode
EPISODE_WINDOW_LEN = 3
@ -112,7 +112,7 @@ async def retrieve_episodes(
),
valid_at=(record['valid_at'].to_native()),
uuid=record['uuid'],
source=record['source'],
source=EpisodeType.from_str(record['source']),
name=record['name'],
source_description=record['source_description'],
)

View File

@ -17,42 +17,71 @@ limitations under the License.
import logging
from datetime import datetime
from time import time
from typing import Any
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.prompts import prompt_library
logger = logging.getLogger(__name__)
async def extract_message_nodes(
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'previous_episodes': [
{
'content': ep.content,
'timestamp': ep.valid_at.isoformat(),
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v2(context))
extracted_node_data = llm_response.get('extracted_nodes', [])
return extracted_node_data
async def extract_json_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'source_description': episode.source_description,
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_json(context)
)
extracted_node_data = llm_response.get('extracted_nodes', [])
return extracted_node_data
async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> list[EntityNode]:
start = time()
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
'previous_episodes': [
{
'content': ep.content,
'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
}
for ep in previous_episodes
],
}
llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v3(context))
new_nodes_data = llm_response.get('new_nodes', [])
extracted_node_data: list[dict[str, Any]] = []
if episode.source == EpisodeType.message:
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
elif episode.source == EpisodeType.json:
extracted_node_data = await extract_json_nodes(llm_client, episode)
end = time()
logger.info(f'Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms')
logger.info(f'Extracted new nodes: {extracted_node_data} in {(end - start) * 1000} ms')
# Convert the extracted data into EntityNode objects
new_nodes = []
for node_data in new_nodes_data:
for node_data in extracted_node_data:
new_node = EntityNode(
name=node_data['name'],
labels=node_data['labels'],

View File

@ -91,6 +91,15 @@ async def invalidate_edges(
return invalidated_edges
def extract_date_strings_from_edge(edge: EntityEdge) -> str:
start = edge.valid_at
end = edge.invalid_at
date_string = f'Start Date: {start.isoformat()}' if start else ''
if end:
date_string += f' (End Date: {end.isoformat()})'
return date_string
def prepare_invalidation_context(
existing_edges: list[NodeEdgeNodeTriplet],
new_edges: list[NodeEdgeNodeTriplet],
@ -99,15 +108,15 @@ def prepare_invalidation_context(
) -> dict:
return {
'existing_edges': [
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
for source_node, edge, target_node in sorted(
existing_edges, key=lambda x: x[1].created_at, reverse=True
existing_edges, key=lambda x: (x[1].created_at), reverse=True
)
],
'new_edges': [
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
for source_node, edge, target_node in sorted(
new_edges, key=lambda x: x[1].created_at, reverse=True
new_edges, key=lambda x: (x[1].created_at), reverse=True
)
],
'current_episode': current_episode.content,

1613
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -22,6 +22,7 @@ sentence-transformers = "^3.0.1"
diskcache = "^5.6.3"
arrow = "^1.3.0"
openai = "^1.38.0"
anthropic = "^0.34.1"
[tool.poetry.dev-dependencies]
pytest = "^8.3.2"
@ -33,6 +34,9 @@ ruff = "^0.6.2"
[tool.poetry.group.dev.dependencies]
pydantic = "^2.8.2"
mypy = "^1.11.1"
groq = "^0.9.0"
ipykernel = "^6.29.5"
jupyterlab = "^4.2.4"
[build-system]
requires = ["poetry-core"]

131
runner.py
View File

@ -1,131 +0,0 @@
import asyncio
import logging
import os
import sys
from datetime import datetime
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
bmw_sales = [
{
'episode_body': 'Paul (buyer): Hi, I would like to buy a new car',
},
{
'episode_body': 'Dan The Salesman (salesman): Sure, I can help you with that. What kind of car are you looking for?',
},
{
'episode_body': 'Paul (buyer): I am looking for a new BMW',
},
{
'episode_body': 'Dan The Salesman (salesman): Great choice! What kind of BMW are you looking for?',
},
{
'episode_body': 'Paul (buyer): I am considering a BMW 3 series',
},
{
'episode_body': 'Dan The Salesman (salesman): Great choice, we currently have a 2024 BMW 3 series in stock, it is a great car and costs $50,000',
},
{
'episode_body': "Paul (buyer): Actually I am interested in something cheaper, I won't consider anything over $30,000",
},
]
dates_mentioned = [
{
'episode_body': 'Paul (user): I have graduated from Univerity of Toronto in 2022',
},
{
'episode_body': 'Jane (user): How cool, I graduated from the same school in 1999',
},
]
times_mentioned = [
{
'episode_body': 'Paul (user): 15 minutes ago we put a deposit on our new house',
},
]
time_range_mentioned = [
{
'episode_body': 'Paul (user): I served as a US Marine in 2015-2019',
},
]
relative_time_range_mentioned = [
{
'episode_body': 'Paul (user): I lived in Toronto for 10 years, until moving to Vancouver yesterday',
},
]
async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
await client.build_indices_and_constraints()
for i, message in enumerate(bmw_sales):
await client.add_episode(
name=f'Message {i}',
episode_body=message['episode_body'],
source_description='',
# reference_time=datetime.now() - timedelta(days=365 * 3),
reference_time=datetime.now(),
)
# await client.add_episode(
# name='Message 5',
# episode_body='Jane: I miss Paul',
# source_description='WhatsApp Message',
# reference_time=datetime.now(),
# )
# await client.add_episode(
# name='Message 6',
# episode_body='Jane: I dont miss Paul anymore, I hate him',
# source_description='WhatsApp Message',
# reference_time=datetime.now(),
# )
# await client.add_episode(
# name="Message 3",
# episode_body="Assistant: The best type of apples available are Fuji apples",
# source_description="WhatsApp Message",
# )
# await client.add_episode(
# name="Message 4",
# episode_body="Paul: Oh, I actually hate those",
# source_description="WhatsApp Message",
# )
asyncio.run(main())

View File

@ -1,10 +1,12 @@
import unittest
from datetime import datetime, timedelta
import pytest
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.maintenance.temporal_operations import (
extract_date_strings_from_edge,
prepare_edges_for_invalidation,
prepare_invalidation_context,
)
@ -153,7 +155,7 @@ def test_prepare_invalidation_context():
content='This is the current episode content.',
created_at=now,
valid_at=now,
source='test',
source=EpisodeType.message,
source_description='Test episode for unit testing',
)
previous_episodes = [
@ -162,7 +164,7 @@ def test_prepare_invalidation_context():
content='This is the content of previous episode 1.',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source=EpisodeType.message,
source_description='Test previous episode 1 for unit testing',
),
EpisodicNode(
@ -170,7 +172,7 @@ def test_prepare_invalidation_context():
content='This is the content of previous episode 2.',
created_at=now - timedelta(days=2),
valid_at=now - timedelta(days=2),
source='test',
source=EpisodeType.message,
source_description='Test previous episode 2 for unit testing',
),
]
@ -197,7 +199,7 @@ def test_prepare_invalidation_context():
assert node1.name in existing_edge_str
assert edge1.name in existing_edge_str
assert node2.name in existing_edge_str
assert edge1.created_at.isoformat() in existing_edge_str
assert edge1.fact in existing_edge_str
# Check the format of the new edge
new_edge_str = result['new_edges'][0]
@ -205,7 +207,7 @@ def test_prepare_invalidation_context():
assert node2.name in new_edge_str
assert edge2.name in new_edge_str
assert node3.name in new_edge_str
assert edge2.created_at.isoformat() in new_edge_str
assert edge2.fact in new_edge_str
def test_prepare_invalidation_context_empty_input():
@ -215,7 +217,7 @@ def test_prepare_invalidation_context_empty_input():
content='Empty episode',
created_at=now,
valid_at=now,
source='test',
source=EpisodeType.message,
source_description='Test empty episode for unit testing',
)
result = prepare_invalidation_context([], [], current_episode, [])
@ -267,7 +269,7 @@ def test_prepare_invalidation_context_sorting():
content='This is the current episode content.',
created_at=now,
valid_at=now,
source='test',
source=EpisodeType.message,
source_description='Test episode for unit testing',
)
previous_episodes = [
@ -276,7 +278,7 @@ def test_prepare_invalidation_context_sorting():
content='This is the content of a previous episode.',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source=EpisodeType.message,
source_description='Test previous episode for unit testing',
),
]
@ -293,6 +295,43 @@ def test_prepare_invalidation_context_sorting():
assert result['previous_episodes'][0] == previous_episodes[0].content
class TestExtractDateStringsFromEdge(unittest.TestCase):
def generate_entity_edge(self, valid_at, invalid_at):
return EntityEdge(
source_node_uuid='1',
target_node_uuid='2',
name='KNOWS',
fact='Node1 knows Node2',
created_at=datetime.now(),
valid_at=valid_at,
invalid_at=invalid_at,
)
def test_both_dates_present(self):
edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), datetime(2024, 1, 2, 12, 0))
result = extract_date_strings_from_edge(edge)
expected = 'Start Date: 2024-01-01T12:00:00 (End Date: 2024-01-02T12:00:00)'
self.assertEqual(result, expected)
def test_only_valid_at_present(self):
edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), None)
result = extract_date_strings_from_edge(edge)
expected = 'Start Date: 2024-01-01T12:00:00'
self.assertEqual(result, expected)
def test_only_invalid_at_present(self):
edge = self.generate_entity_edge(None, datetime(2024, 1, 2, 12, 0))
result = extract_date_strings_from_edge(edge)
expected = ' (End Date: 2024-01-02T12:00:00)'
self.assertEqual(result, expected)
def test_no_dates_present(self):
edge = self.generate_entity_edge(None, None)
result = extract_date_strings_from_edge(edge)
expected = ''
self.assertEqual(result, expected)
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])

View File

@ -6,7 +6,7 @@ from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.maintenance.temporal_operations import (
invalidate_edges,
)
@ -58,7 +58,7 @@ def create_test_data():
content='Alice now dislikes Bob',
created_at=now,
valid_at=now,
source='test',
source=EpisodeType.message,
source_description='Test episode for unit testing',
)
@ -69,7 +69,7 @@ def create_test_data():
content='Alice liked Bob',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source='test',
source=EpisodeType.message,
source_description='Test previous episode for unit testing',
)
]