mirror of
https://github.com/getzep/graphiti.git
synced 2025-12-28 07:33:30 +00:00
Controlled example (#37)
* chore: Add romeo runner * fix: Linter * dedupe fixes * wip * wip dump * allbirds * chore: Update romeo parser * chore: Anthropic model fix * allbirds runner * format * wip * mypy updates * update * remove r * update tests * format * wip * wip * wip * chore: Strategically update the message * chore: Add romeo runner * fix: Linter * wip * wip dump * chore: Update romeo parser * chore: Anthropic model fix * wip * allbirds * allbirds runner * format * wip * wip * mypy updates * update * remove r * update tests * format * wip * chore: Strategically update the message * rebase and fix import issues * Update package imports for graphiti_core in examples and utils * nits * chore: Update OpenAI GPT-4o model to gpt-4o-2024-08-06 * implement groq * improvments & linting * cleanup and nits * Refactor package imports for graphiti_core in examples and utils * Refactor package imports for graphiti_core in examples and utils * chore: Nuke unused examples * chore: Nuke unused examples * chore: Only run type check on graphiti_core * fix unit tests * reformat * unit test * fix: Unit tests * test: Add coverage for extract_date_strings_from_edge * lint * remove commented code --------- Co-authored-by: prestonrasmussen <prasmuss15@gmail.com> Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
This commit is contained in:
parent
c5e52153c4
commit
0ed7739bc0
2
.github/workflows/typecheck.yml
vendored
2
.github/workflows/typecheck.yml
vendored
@ -37,7 +37,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
set -o pipefail
|
||||
poetry run mypy . --show-column-numbers --show-error-codes | sed -E '
|
||||
poetry run mypy ./graphiti_core --show-column-numbers --show-error-codes | sed -E '
|
||||
s/^(.*):([0-9]+):([0-9]+): (error|warning): (.+) \[(.+)\]/::error file=\1,line=\2,endLine=\2,col=\3,title=\6::\5/;
|
||||
s/^(.*):([0-9]+):([0-9]+): note: (.+)/::notice file=\1,line=\2,endLine=\2,col=\3,title=Note::\4/;
|
||||
'
|
||||
|
||||
1226
examples/ecommerce/allbirds_products.json
Normal file
1226
examples/ecommerce/allbirds_products.json
Normal file
File diff suppressed because it is too large
Load Diff
117
examples/ecommerce/runner.py
Normal file
117
examples/ecommerce/runner.py
Normal file
@ -0,0 +1,117 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
from graphiti_core.utils.bulk_utils import RawEpisode
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||
|
||||
load_dotenv()
|
||||
|
||||
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
||||
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
|
||||
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
|
||||
|
||||
|
||||
def setup_logging():
    """Configure the root logger to emit INFO-level records on stdout.

    A new ``StreamHandler`` bound to ``sys.stdout`` is attached on every
    call, so calling this more than once adds more than one handler.

    Returns:
        The root logger, after its level and handler have been set.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Console output: INFO and above, "<logger name> - <level> - <message>".
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        logging.Formatter('%(name)s - %(levelname)s - %(message)s')
    )

    root_logger.addHandler(stdout_handler)
    return root_logger
|
||||
|
||||
|
||||
shoe_conversation = [
|
||||
"SalesBot: Hi, I'm Allbirds Assistant! How can I help you today?",
|
||||
"John: Hi, I'm looking for a new pair of shoes.",
|
||||
'SalesBot: Of course! What kinde of material are you looking for?',
|
||||
"John: I'm looking for shoes made out of wool",
|
||||
"""SalesBot: We have just what you are looking for, how do you like our Men's SuperLight Wool Runners
|
||||
- Dark Grey (Medium Grey Sole)? They use the SuperLight Foam technology.""",
|
||||
"""John: Oh, actually I bought those 2 months ago, but unfortunately found out that I was allergic to wool.
|
||||
I think I will pass on those, maybe there is something with a retro look that you could suggest?""",
|
||||
"""SalesBot: Im sorry to hear that! Would you be interested in Men's Couriers -
|
||||
(Blizzard Sole) model? We have them in Natural Black and Basin Blue colors""",
|
||||
'John: Oh that is perfect, I LOVE the Natural Black color!. I will take those.',
|
||||
]
|
||||
|
||||
|
||||
async def add_messages(client: Graphiti):
    """Add every line of ``shoe_conversation`` to the graph as a message episode.

    Episodes are added one at a time, in conversation order, each awaited
    before the next is submitted. They are named ``Message <index>`` with a
    zero-based index and stamped with the wall-clock time of submission.

    Args:
        client: The Graphiti instance that receives the episodes.
    """
    for position, utterance in enumerate(shoe_conversation):
        await client.add_episode(
            name=f'Message {position}',
            episode_body=utterance,
            source=EpisodeType.message,
            reference_time=datetime.now(),
            source_description='Shoe conversation',
        )
|
||||
|
||||
|
||||
async def main():
    """Run the e-commerce example end to end.

    Steps, in order: configure logging, connect to Neo4j with the
    module-level connection settings, clear the existing graph data, build
    indices and constraints, bulk-ingest the product catalogue, then add the
    sample sales conversation.
    """
    setup_logging()
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)

    # Start from an empty graph: this deletes whatever the database holds.
    await clear_data(graphiti.driver)
    await graphiti.build_indices_and_constraints()

    # Products first, so the conversation is ingested against them.
    await ingest_products_data(graphiti)
    await add_messages(graphiti)
|
||||
|
||||
|
||||
async def ingest_products_data(client: Graphiti):
    """Bulk-load the product catalogue stored next to this script.

    Reads ``allbirds_products.json`` from the script's directory, takes the
    list under its top-level ``products`` key, and submits one JSON episode
    per product through ``client.add_episode_bulk``.

    Args:
        client: The Graphiti instance that receives the episodes.

    Raises:
        FileNotFoundError: If the catalogue file is missing.
        KeyError: If the file has no top-level ``products`` key.
    """
    json_file_path = Path(__file__).parent / 'allbirds_products.json'

    # Explicit encoding so the read does not depend on the platform locale.
    with open(json_file_path, encoding='utf-8') as file:
        products = json.load(file)['products']

    episodes: list[RawEpisode] = [
        RawEpisode(
            name=f'Product {i}',
            # The episode is tagged as JSON, so serialize it as JSON;
            # str(dict) would produce a Python repr, which is not valid JSON.
            content=json.dumps(product, ensure_ascii=False),
            source_description='Allbirds products',
            source=EpisodeType.json,
            reference_time=datetime.now(),
        )
        for i, product in enumerate(products)
    ]

    await client.add_episode_bulk(episodes)
|
||||
|
||||
|
||||
# Run the example only when executed as a script. main() clears the graph
# database, so it must not fire as a side effect of importing this module.
if __name__ == '__main__':
    asyncio.run(main())
|
||||
@ -23,7 +23,8 @@ from dotenv import load_dotenv
|
||||
from transcript_parser import parse_podcast_messages
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.utils.bulk_utils import BulkEpisode
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
from graphiti_core.utils.bulk_utils import RawEpisode
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||
|
||||
load_dotenv()
|
||||
@ -70,12 +71,12 @@ async def main(use_bulk: bool = True):
|
||||
source_description='Podcast Transcript',
|
||||
)
|
||||
|
||||
episodes: list[BulkEpisode] = [
|
||||
BulkEpisode(
|
||||
episodes: list[RawEpisode] = [
|
||||
RawEpisode(
|
||||
name=f'Message {i}',
|
||||
content=f'{message.speaker_name} ({message.role}): {message.content}',
|
||||
source=EpisodeType.message,
|
||||
source_description='Podcast Transcript',
|
||||
episode_type='string',
|
||||
reference_time=message.actual_timestamp,
|
||||
)
|
||||
for i, message in enumerate(messages[3:14])
|
||||
|
||||
36
examples/wizard_of_oz/parser.py
Normal file
36
examples/wizard_of_oz/parser.py
Normal file
@ -0,0 +1,36 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def parse_wizard_of_oz(file_path):
    """Split a chapter-structured text file into one record per chapter.

    A chapter starts at a line of the form ``Chapter <roman numeral>`` (using
    only the letters I, V and X) that is preceded by at least one blank line.
    Text before the first such heading is discarded.

    Args:
        file_path: Path to the UTF-8 text file to parse.

    Returns:
        A list of dicts with keys ``episode_number`` (1-based position),
        ``title`` and ``content``. The title is the chapter's first line when
        that line is immediately followed by a blank line; otherwise it falls
        back to ``Chapter <episode_number>`` and the whole chapter is the
        content.
    """
    with open(file_path, encoding='utf-8') as book:
        full_text = book.read()

    # Element 0 of the split is everything before "Chapter I"; drop it.
    sections = re.split(r'\n\n+Chapter [IVX]+\n', full_text)
    chapter_texts = sections[1:]

    parsed = []
    for number, chapter_text in enumerate(chapter_texts, start=1):
        heading = re.match(r'(.*?)\n\n', chapter_text)
        if heading:
            title = heading.group(1)
            # The title sits at the very start of the chapter, so slicing by
            # its length removes exactly the heading line.
            body = chapter_text[len(title):].strip()
        else:
            title = f'Chapter {number}'
            body = chapter_text.strip()

        parsed.append({'episode_number': number, 'title': title, 'content': body})

    return parsed
|
||||
|
||||
|
||||
def get_wizard_of_oz_messages():
    """Parse the ``woo.txt`` file that sits beside this module.

    Returns:
        The list of chapter records produced by ``parse_wizard_of_oz``.
    """
    # Resolve against this module's directory so the lookup does not depend
    # on the caller's working directory.
    module_dir = os.path.dirname(__file__)
    book_path = os.path.join(module_dir, 'woo.txt')
    return parse_wizard_of_oz(book_path)
|
||||
93
examples/wizard_of_oz/runner.py
Normal file
93
examples/wizard_of_oz/runner.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from examples.wizard_of_oz.parser import get_wizard_of_oz_messages
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.llm_client.anthropic_client import AnthropicClient
|
||||
from graphiti_core.llm_client.config import LLMConfig
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||
|
||||
load_dotenv()
|
||||
|
||||
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
||||
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
||||
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||
|
||||
|
||||
def setup_logging():
    """Configure the root logger to emit timestamped INFO records on stdout.

    A new ``StreamHandler`` bound to ``sys.stdout`` is attached on every
    call, so calling this more than once adds more than one handler.

    Returns:
        The root logger, after its level and handler have been set.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # Console output: INFO and above, prefixed with the record timestamp.
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    root_logger.addHandler(stdout_handler)
    return root_logger
|
||||
|
||||
|
||||
async def main():
    """Ingest the parsed book into a freshly cleared graph, chapter by chapter.

    Steps, in order: configure logging, build an Anthropic-backed Graphiti
    client from the module-level Neo4j settings and the ``ANTHROPIC_API_KEY``
    environment variable, parse the chapters, clear the existing graph data,
    build indices and constraints, then add each chapter as its own episode.
    """
    setup_logging()
    log = logging.getLogger(__name__)

    llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY')))
    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password, llm_client)

    messages = get_wizard_of_oz_messages()
    # Report only the count; dumping the parsed book would flood the console.
    log.info('Parsed %d chapters', len(messages))

    now = datetime.now()

    # Start from an empty graph: this deletes whatever the database holds.
    await clear_data(client.driver)
    await client.build_indices_and_constraints()

    for i, chapter in enumerate(messages):
        await client.add_episode(
            name=f'Chapter {i + 1}',
            episode_body=chapter['content'],
            source_description='Wizard of Oz Transcript',
            # Space the reference times 10 seconds apart so that chapter
            # order is reflected in the episode timestamps.
            reference_time=now + timedelta(seconds=i * 10),
        )
|
||||
|
||||
|
||||
# Run the example only when executed as a script. main() clears the graph
# database, so it must not fire as a side effect of importing this module.
if __name__ == '__main__':
    asyncio.run(main())
|
||||
4671
examples/wizard_of_oz/woo.txt
Normal file
4671
examples/wizard_of_oz/woo.txt
Normal file
File diff suppressed because it is too large
Load Diff
@ -26,7 +26,7 @@ from neo4j import AsyncGraphDatabase
|
||||
|
||||
from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||
from graphiti_core.llm_client import LLMClient, LLMConfig, OpenAIClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search import SearchConfig, hybrid_search
|
||||
from graphiti_core.search.search_utils import (
|
||||
get_relevant_edges,
|
||||
@ -37,7 +37,7 @@ from graphiti_core.utils import (
|
||||
retrieve_episodes,
|
||||
)
|
||||
from graphiti_core.utils.bulk_utils import (
|
||||
BulkEpisode,
|
||||
RawEpisode,
|
||||
dedupe_edges_bulk,
|
||||
dedupe_nodes_bulk,
|
||||
extract_nodes_and_edges_bulk,
|
||||
@ -74,8 +74,7 @@ class Graphiti:
|
||||
self.llm_client = OpenAIClient(
|
||||
LLMConfig(
|
||||
api_key=os.getenv('OPENAI_API_KEY', default=''),
|
||||
model='gpt-4o-mini',
|
||||
base_url='https://api.openai.com/v1',
|
||||
model='gpt-4o-2024-08-06',
|
||||
)
|
||||
)
|
||||
|
||||
@ -99,6 +98,7 @@ class Graphiti:
|
||||
episode_body: str,
|
||||
source_description: str,
|
||||
reference_time: datetime,
|
||||
source: EpisodeType = EpisodeType.message,
|
||||
success_callback: Callable | None = None,
|
||||
error_callback: Callable | None = None,
|
||||
):
|
||||
@ -112,11 +112,11 @@ class Graphiti:
|
||||
embedder = self.llm_client.get_embedder()
|
||||
now = datetime.now()
|
||||
|
||||
previous_episodes = await self.retrieve_episodes(reference_time)
|
||||
previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
|
||||
episode = EpisodicNode(
|
||||
name=name,
|
||||
labels=[],
|
||||
source='messages',
|
||||
source=source,
|
||||
content=episode_body,
|
||||
source_description=source_description,
|
||||
created_at=now,
|
||||
@ -149,12 +149,6 @@ class Graphiti:
|
||||
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
|
||||
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
|
||||
|
||||
# deduped_edges = await dedupe_extracted_edges_v2(
|
||||
# self.llm_client,
|
||||
# extract_node_and_edge_triplets(extracted_edges, nodes),
|
||||
# extract_node_and_edge_triplets(existing_edges, nodes),
|
||||
# )
|
||||
|
||||
deduped_edges = await dedupe_extracted_edges(
|
||||
self.llm_client,
|
||||
extracted_edges,
|
||||
@ -166,6 +160,30 @@ class Graphiti:
|
||||
edge_touched_node_uuids.append(edge.source_node_uuid)
|
||||
edge_touched_node_uuids.append(edge.target_node_uuid)
|
||||
|
||||
for edge in deduped_edges:
|
||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||
self.llm_client,
|
||||
edge,
|
||||
episode.valid_at,
|
||||
episode,
|
||||
previous_episodes,
|
||||
)
|
||||
edge.valid_at = valid_at
|
||||
edge.invalid_at = invalid_at
|
||||
if edge.invalid_at:
|
||||
edge.expired_at = datetime.now()
|
||||
for edge in existing_edges:
|
||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||
self.llm_client,
|
||||
edge,
|
||||
episode.valid_at,
|
||||
episode,
|
||||
previous_episodes,
|
||||
)
|
||||
edge.valid_at = valid_at
|
||||
edge.invalid_at = invalid_at
|
||||
if edge.invalid_at:
|
||||
edge.expired_at = datetime.now()
|
||||
(
|
||||
old_edges_with_nodes_pending_invalidation,
|
||||
new_edges_with_nodes,
|
||||
@ -190,19 +208,10 @@ class Graphiti:
|
||||
deduped_edge.expired_at = edge.expired_at
|
||||
edge_touched_node_uuids.append(edge.source_node_uuid)
|
||||
edge_touched_node_uuids.append(edge.target_node_uuid)
|
||||
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
|
||||
|
||||
edges_to_save = existing_edges + deduped_edges
|
||||
|
||||
for edge_to_extract_dates_from in edges_to_save:
|
||||
valid_at, invalid_at, _ = await extract_edge_dates(
|
||||
self.llm_client,
|
||||
edge_to_extract_dates_from,
|
||||
episode.valid_at,
|
||||
episode,
|
||||
previous_episodes,
|
||||
)
|
||||
edge_to_extract_dates_from.valid_at = valid_at
|
||||
edge_to_extract_dates_from.invalid_at = invalid_at
|
||||
entity_edges.extend(edges_to_save)
|
||||
|
||||
edge_touched_node_uuids = list(set(edge_touched_node_uuids))
|
||||
@ -210,8 +219,6 @@ class Graphiti:
|
||||
|
||||
logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
|
||||
|
||||
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
|
||||
|
||||
logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
|
||||
|
||||
episodic_edges.extend(
|
||||
@ -232,7 +239,7 @@ class Graphiti:
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
|
||||
|
||||
end = time()
|
||||
logger.info(f'Completed add_episode in {(end-start) * 1000} ms')
|
||||
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
||||
# for node in nodes:
|
||||
# if isinstance(node, EntityNode):
|
||||
# await node.update_summary(self.driver)
|
||||
@ -246,7 +253,7 @@ class Graphiti:
|
||||
|
||||
async def add_episode_bulk(
|
||||
self,
|
||||
bulk_episodes: list[BulkEpisode],
|
||||
bulk_episodes: list[RawEpisode],
|
||||
):
|
||||
try:
|
||||
start = time()
|
||||
@ -257,7 +264,7 @@ class Graphiti:
|
||||
EpisodicNode(
|
||||
name=episode.name,
|
||||
labels=[],
|
||||
source='messages',
|
||||
source=episode.source,
|
||||
content=episode.content,
|
||||
source_description=episode.source_description,
|
||||
created_at=now,
|
||||
@ -316,7 +323,7 @@ class Graphiti:
|
||||
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
|
||||
|
||||
end = time()
|
||||
logger.info(f'Completed add_episode_bulk in {(end-start) * 1000} ms')
|
||||
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .client import LLMClient
|
||||
from .config import LLMConfig
|
||||
from .groq_client import GroqClient
|
||||
from .openai_client import OpenAIClient
|
||||
|
||||
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
|
||||
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'AnthropicClient', 'GroqClient']
|
||||
|
||||
60
graphiti_core/llm_client/anthropic_client.py
Normal file
60
graphiti_core/llm_client/anthropic_client.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import typing
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..prompts.models import Message
|
||||
from .client import LLMClient
|
||||
from .config import LLMConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient(LLMClient):
    """LLM client that sends prompts to Anthropic's async Messages API.

    NOTE: ``config.model`` is stored on ``self.model``, but requests are
    always sent to the hard-coded ``claude-3-5-sonnet-20240620`` model.
    """

    def __init__(self, config: LLMConfig | None = None):
        """Create the client.

        Args:
            config: Client configuration; a default ``LLMConfig()`` is used
                when omitted. Only ``api_key`` and ``model`` are read.
        """
        if config is None:
            config = LLMConfig()

        self.model = config.model
        self.client = AsyncAnthropic(api_key=config.api_key)

    def get_embedder(self) -> typing.Any:
        """Return the ``embeddings`` resource of a new ``AsyncOpenAI`` client.

        The OpenAI client is built with no explicit arguments on every call.
        NOTE(review): presumably it takes its credentials from the
        environment; confirm against the OpenAI SDK.
        """
        return AsyncOpenAI().embeddings

    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
        """Send ``messages`` to the model and return its reply parsed as JSON.

        The first message supplies the system prompt and the rest form the
        conversation. A trailing assistant turn containing ``{`` is appended
        so the model continues a JSON object, and that brace is put back in
        front of the reply before parsing.

        Args:
            messages: Prompt messages; must be non-empty, system message first.

        Returns:
            The decoded JSON object.

        Raises:
            IndexError: If ``messages`` is empty (raised before the request).
            Exception: Any request or JSON-decoding error, logged and re-raised.
        """
        system_message = messages[0]

        conversation = [{'role': m.role, 'content': m.content} for m in messages[1:]]
        # Prefill the assistant turn so the completion starts inside an object.
        conversation.append({'role': 'assistant', 'content': '{'})

        system_prompt = (
            'Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
            + system_message.content
        )

        try:
            result = await self.client.messages.create(
                system=system_prompt,
                max_tokens=4096,
                messages=conversation,  # type: ignore
                model='claude-3-5-sonnet-20240620',
            )
            reply_text = result.content[0].text  # type: ignore
            return json.loads('{' + reply_text)
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise
|
||||
@ -23,7 +23,7 @@ from .config import LLMConfig
|
||||
|
||||
class LLMClient(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, config: LLMConfig):
|
||||
def __init__(self, config: LLMConfig | None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -28,9 +28,9 @@ class LLMConfig:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
api_key: str | None = None,
|
||||
model: str = 'gpt-4o-mini',
|
||||
base_url: str = 'https://api.openai.com',
|
||||
base_url: str = 'https://api.openai.com/v1',
|
||||
):
|
||||
"""
|
||||
Initialize the LLMConfig with the provided parameters.
|
||||
|
||||
63
graphiti_core/llm_client/groq_client.py
Normal file
63
graphiti_core/llm_client/groq_client.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import typing
|
||||
|
||||
from groq import AsyncGroq
|
||||
from groq.types.chat import ChatCompletionMessageParam
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..prompts.models import Message
|
||||
from .client import LLMClient
|
||||
from .config import LLMConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroqClient(LLMClient):
    """LLM client that sends prompts to Groq's async chat-completions API.

    NOTE: ``config.model`` is stored on ``self.model``, but requests are
    always sent to the hard-coded ``llama-3.1-70b-versatile`` model.
    """

    def __init__(self, config: LLMConfig | None = None):
        """Create the client.

        Args:
            config: Client configuration; a default ``LLMConfig()`` is used
                when omitted. Only ``api_key`` and ``model`` are read.
        """
        if config is None:
            config = LLMConfig()
        self.client = AsyncGroq(api_key=config.api_key)
        self.model = config.model

    def get_embedder(self) -> typing.Any:
        """Return the ``embeddings`` resource of a new ``AsyncOpenAI`` client.

        The OpenAI client is built with no explicit arguments on every call.
        NOTE(review): presumably it takes its credentials from the
        environment; confirm against the OpenAI SDK.
        """
        openai_client = AsyncOpenAI()
        return openai_client.embeddings

    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
        """Send ``messages`` to the model and return its reply parsed as JSON.

        Only messages with role ``user`` or ``system`` are forwarded; any
        other role is dropped. The request asks for a JSON-object response.

        Args:
            messages: Prompt messages to send.

        Returns:
            The decoded JSON object.

        Raises:
            Exception: Any request or JSON-decoding error (including an empty
                reply), logged and re-raised.
        """
        openai_messages: list[ChatCompletionMessageParam] = []
        for m in messages:
            if m.role == 'user':
                openai_messages.append({'role': 'user', 'content': m.content})
            elif m.role == 'system':
                openai_messages.append({'role': 'system', 'content': m.content})
        try:
            response = await self.client.chat.completions.create(
                model='llama-3.1-70b-versatile',
                messages=openai_messages,
                temperature=0.0,
                max_tokens=4096,
                response_format={'type': 'json_object'},
            )
            result = response.choices[0].message.content or ''
            return json.loads(result)
        except Exception as e:
            # Keep the failing prompt available for debugging, but through
            # the logger at DEBUG level; library code should not write whole
            # prompts to stdout.
            logger.debug('Messages for failed LLM request: %s', openai_messages)
            logger.error(f'Error in generating LLM response: {e}')
            raise
|
||||
@ -29,7 +29,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient(LLMClient):
|
||||
def __init__(self, config: LLMConfig):
|
||||
def __init__(self, config: LLMConfig | None = None):
|
||||
if config is None:
|
||||
config = LLMConfig()
|
||||
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
self.model = config.model
|
||||
|
||||
@ -48,11 +50,12 @@ class OpenAIClient(LLMClient):
|
||||
model=self.model,
|
||||
messages=openai_messages,
|
||||
temperature=0,
|
||||
max_tokens=3000,
|
||||
max_tokens=4096,
|
||||
response_format={'type': 'json_object'},
|
||||
)
|
||||
result = response.choices[0].message.content or ''
|
||||
return json.loads(result)
|
||||
except Exception as e:
|
||||
print(openai_messages)
|
||||
logger.error(f'Error in generating LLM response: {e}')
|
||||
raise
|
||||
|
||||
@ -17,6 +17,7 @@ limitations under the License.
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from time import time
|
||||
from uuid import uuid4
|
||||
|
||||
@ -29,6 +30,20 @@ from graphiti_core.llm_client.config import EMBEDDING_DIM
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeType(Enum):
    """Source type of an episode: ``message`` or ``json``."""

    message = 'message'
    json = 'json'

    @staticmethod
    def from_str(episode_type: str):
        """Return the member whose value equals ``episode_type``.

        Args:
            episode_type: Either ``'message'`` or ``'json'``.

        Returns:
            The matching ``EpisodeType`` member.

        Raises:
            NotImplementedError: For any other value, after logging an error.
        """
        for member in (EpisodeType.message, EpisodeType.json):
            if episode_type == member.value:
                return member

        logger.error(f'Episode type: {episode_type} not implemented')
        raise NotImplementedError
|
||||
|
||||
|
||||
class Node(BaseModel, ABC):
|
||||
uuid: str = Field(default_factory=lambda: uuid4().hex)
|
||||
name: str
|
||||
@ -48,7 +63,7 @@ class Node(BaseModel, ABC):
|
||||
|
||||
|
||||
class EpisodicNode(Node):
|
||||
source: str = Field(description='source type')
|
||||
source: EpisodeType = Field(description='source type')
|
||||
source_description: str = Field(description='description of the data source')
|
||||
content: str = Field(description='raw episode data')
|
||||
valid_at: datetime = Field(
|
||||
@ -73,7 +88,7 @@ class EpisodicNode(Node):
|
||||
entity_edges=self.entity_edges,
|
||||
created_at=self.created_at,
|
||||
valid_at=self.valid_at,
|
||||
source=self.source,
|
||||
source=self.source.value,
|
||||
_database='neo4j',
|
||||
)
|
||||
|
||||
@ -96,7 +111,7 @@ class EntityNode(Node):
|
||||
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
|
||||
self.name_embedding = embedding[:EMBEDDING_DIM]
|
||||
end = time()
|
||||
logger.info(f'embedded {text} in {end-start} ms')
|
||||
logger.info(f'embedded {text} in {end - start} ms')
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
@ -88,7 +88,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||
|
||||
New Nodes:
|
||||
{json.dumps(context['extracted_nodes'], indent=2)}
|
||||
|
||||
Important:
|
||||
If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
|
||||
Task:
|
||||
If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||
Reference Timestamp: {context['reference_timestamp']}
|
||||
|
||||
IMPORTANT: Only extract time information if it is part of the provided fact. Otherwise ignore the time mentioned. Make sure to do your best to determine the dates if only the relative time is mentioned. (eg 10 years ago, 2 mins ago) based on the provided reference timestamp
|
||||
|
||||
If the relationship is not of spanning nature, but you are still able to determine the dates, set the valid_at only.
|
||||
Definitions:
|
||||
- valid_at: The date and time when the relationship described by the edge fact became true or was established.
|
||||
- invalid_at: The date and time when the relationship described by the edge fact stopped being true or ended.
|
||||
|
||||
@ -23,64 +23,16 @@ from .models import Message, PromptFunction, PromptVersion
|
||||
class Prompt(Protocol):
|
||||
v1: PromptVersion
|
||||
v2: PromptVersion
|
||||
v3: PromptVersion
|
||||
extract_json: PromptVersion
|
||||
|
||||
|
||||
class Versions(TypedDict):
|
||||
v1: PromptFunction
|
||||
v2: PromptFunction
|
||||
v3: PromptFunction
|
||||
extract_json: PromptFunction
|
||||
|
||||
|
||||
def v1(context: dict[str, Any]) -> list[Message]:
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
content='You are a helpful assistant that extracts graph nodes from provided context.',
|
||||
),
|
||||
Message(
|
||||
role='user',
|
||||
content=f"""
|
||||
Given the following context, extract new semantic nodes that need to be added to the knowledge graph:
|
||||
|
||||
Existing Nodes:
|
||||
{json.dumps(context['existing_nodes'], indent=2)}
|
||||
|
||||
Previous Episodes:
|
||||
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
|
||||
|
||||
New Episode:
|
||||
Content: {context["episode_content"]}
|
||||
Timestamp: {context['episode_timestamp']}
|
||||
|
||||
Extract new semantic nodes based on the content of the current episode, while considering the existing nodes and context from previous episodes.
|
||||
|
||||
Guidelines:
|
||||
1. Only extract new nodes that don't already exist in the graph structure.
|
||||
2. Focus on entities, concepts, or actors that are central to the current episode.
|
||||
3. Avoid creating nodes for relationships or actions (these will be handled as edges later).
|
||||
4. Provide a brief but informative summary for each node.
|
||||
5. If a node seems to represent an existing concept but with updated information, don't create a new node. This will be handled by edge updates.
|
||||
6. Do not create nodes for episodic content (like Message 1 or Message 2).
|
||||
|
||||
Respond with a JSON object in the following format:
|
||||
{{
|
||||
"new_nodes": [
|
||||
{{
|
||||
"name": "Unique identifier for the node",
|
||||
"labels": ["Semantic", "OptionalAdditionalLabel"],
|
||||
"summary": "Brief summary of the node's role or significance"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
If no new nodes need to be added, return an empty list for "new_nodes".
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def v2(context: dict[str, Any]) -> list[Message]:
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
@ -121,7 +73,7 @@ def v2(context: dict[str, Any]) -> list[Message]:
|
||||
]
|
||||
|
||||
|
||||
def v3(context: dict[str, Any]) -> list[Message]:
|
||||
def v2(context: dict[str, Any]) -> list[Message]:
|
||||
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
|
||||
|
||||
user_prompt = f"""
|
||||
@ -141,7 +93,7 @@ Guidelines:
|
||||
|
||||
Respond with a JSON object in the following format:
|
||||
{{
|
||||
"new_nodes": [
|
||||
"extracted_nodes": [
|
||||
{{
|
||||
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
|
||||
"labels": ["Entity", "Speaker" for speaker nodes, "OptionalAdditionalLabel"],
|
||||
@ -156,4 +108,38 @@ Respond with a JSON object in the following format:
|
||||
]
|
||||
|
||||
|
||||
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3}
|
||||
def extract_json(context: dict[str, Any]) -> list[Message]:
    """Build the prompt messages for extracting entity nodes from a JSON episode.

    Args:
        context: Prompt context. The keys ``source_description`` and
            ``episode_content`` are read and interpolated into the user prompt.

    Returns:
        A two-element list: the system message followed by the user message.

    Raises:
        KeyError: If ``source_description`` or ``episode_content`` is missing
            from ``context``.
    """
    sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text.
Your primary task is to identify and extract relevant entities from JSON files"""

    # Doubled braces ({{ }}) are literal braces in the f-string; they render the
    # JSON response template that the model is asked to follow.
    user_prompt = f"""
Given the following source description, extract relevant entity nodes from the provided JSON:

Source Description:
{context["source_description"]}

JSON:
{context["episode_content"]}

Guidelines:
1. Always try to extract an entity that the JSON represents. This will often be something like a "name" or "user" field
2. Do NOT extract any properties that contain dates

Respond with a JSON object in the following format:
{{
"extracted_nodes": [
{{
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
"labels": ["Entity", "Speaker" for speaker nodes, "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
|
||||
|
||||
|
||||
versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}
|
||||
|
||||
@ -36,9 +36,10 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||
Message(
|
||||
role='user',
|
||||
content=f"""
|
||||
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
|
||||
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
|
||||
Do not invalidate relationships merely because they weren't mentioned in new edges. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships.
|
||||
Based on the provided existing edges and new edges with their timestamps, determine which relationships, if any, should be marked as expired due to contradictions or updates in the newer edges.
|
||||
Use the start and end dates of the edges to determine which edges are to be marked expired.
|
||||
Only mark a relationship as invalid if there is clear evidence from other edges that the relationship is no longer true.
|
||||
Do not invalidate relationships merely because they weren't mentioned in the episodes. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships.
|
||||
|
||||
Previous Episodes:
|
||||
{context['previous_episodes']}
|
||||
@ -52,7 +53,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
|
||||
New Edges:
|
||||
{context['new_edges']}
|
||||
|
||||
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), TIMESTAMP)"
|
||||
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), START_DATE (END_DATE, optional))"
|
||||
|
||||
For each existing edge that should be invalidated, respond with a JSON object in the following format:
|
||||
{{
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
@ -186,7 +187,7 @@ async def entity_fulltext_search(
|
||||
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
||||
) -> list[EntityNode]:
|
||||
# BM25 search to get top nodes
|
||||
fuzzy_query = query + '~'
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
|
||||
@ -221,7 +222,7 @@ async def edge_fulltext_search(
|
||||
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
|
||||
) -> list[EntityEdge]:
|
||||
# fulltext search over facts
|
||||
fuzzy_query = query + '~'
|
||||
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
|
||||
|
||||
records, _, _ = await driver.execute_query(
|
||||
"""
|
||||
|
||||
@ -24,7 +24,7 @@ from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
|
||||
from graphiti_core.utils import retrieve_episodes
|
||||
from graphiti_core.utils.maintenance.edge_operations import (
|
||||
@ -43,11 +43,11 @@ from graphiti_core.utils.maintenance.node_operations import (
|
||||
CHUNK_SIZE = 15
|
||||
|
||||
|
||||
class BulkEpisode(BaseModel):
|
||||
class RawEpisode(BaseModel):
|
||||
name: str
|
||||
content: str
|
||||
source_description: str
|
||||
episode_type: str
|
||||
source: EpisodeType
|
||||
reference_time: datetime
|
||||
|
||||
|
||||
|
||||
@ -70,6 +70,7 @@ async def extract_edges(
|
||||
}
|
||||
|
||||
llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context))
|
||||
print(llm_response)
|
||||
edges_data = llm_response.get('edges', [])
|
||||
|
||||
end = time()
|
||||
|
||||
@ -21,7 +21,7 @@ from datetime import datetime, timezone
|
||||
from neo4j import AsyncDriver
|
||||
from typing_extensions import LiteralString
|
||||
|
||||
from graphiti_core.nodes import EpisodicNode
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
|
||||
EPISODE_WINDOW_LEN = 3
|
||||
|
||||
@ -112,7 +112,7 @@ async def retrieve_episodes(
|
||||
),
|
||||
valid_at=(record['valid_at'].to_native()),
|
||||
uuid=record['uuid'],
|
||||
source=record['source'],
|
||||
source=EpisodeType.from_str(record['source']),
|
||||
name=record['name'],
|
||||
source_description=record['source_description'],
|
||||
)
|
||||
|
||||
@ -17,42 +17,71 @@ limitations under the License.
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
||||
from graphiti_core.llm_client import LLMClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.prompts import prompt_library
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def extract_message_nodes(
    llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
    """Ask the LLM to extract entity nodes from a message-type episode.

    Args:
        llm_client: Client whose ``generate_response`` is awaited with the prompt.
        episode: Episode being processed; its ``content`` and ``valid_at`` are read.
        previous_episodes: Earlier episodes supplied to the prompt as history
            (``content`` and ``valid_at`` of each are read).

    Returns:
        The value stored under ``'extracted_nodes'`` in the LLM response, or an
        empty list when that key is absent.
    """
    prompt_context: dict[str, Any] = {
        'episode_content': episode.content,
        'episode_timestamp': episode.valid_at.isoformat(),
    }

    # Each prior episode is reduced to its text and ISO-formatted timestamp.
    history = []
    for prior in previous_episodes:
        history.append({'content': prior.content, 'timestamp': prior.valid_at.isoformat()})
    prompt_context['previous_episodes'] = history

    messages = prompt_library.extract_nodes.v2(prompt_context)
    reply = await llm_client.generate_response(messages)
    return reply.get('extracted_nodes', [])
|
||||
|
||||
|
||||
async def extract_json_nodes(
    llm_client: LLMClient,
    episode: EpisodicNode,
) -> list[dict[str, Any]]:
    """Ask the LLM to extract entity nodes from a JSON-type episode.

    Args:
        llm_client: Client whose ``generate_response`` is awaited with the prompt.
        episode: Episode being processed; its ``content``, ``valid_at`` and
            ``source_description`` are read.

    Returns:
        The value stored under ``'extracted_nodes'`` in the LLM response, or an
        empty list when that key is absent.
    """
    prompt_context = dict(
        episode_content=episode.content,
        episode_timestamp=episode.valid_at.isoformat(),
        source_description=episode.source_description,
    )

    messages = prompt_library.extract_nodes.extract_json(prompt_context)
    reply = await llm_client.generate_response(messages)
    return reply.get('extracted_nodes', [])
|
||||
|
||||
|
||||
async def extract_nodes(
|
||||
llm_client: LLMClient,
|
||||
episode: EpisodicNode,
|
||||
previous_episodes: list[EpisodicNode],
|
||||
) -> list[EntityNode]:
|
||||
start = time()
|
||||
|
||||
# Prepare context for LLM
|
||||
context = {
|
||||
'episode_content': episode.content,
|
||||
'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
|
||||
'previous_episodes': [
|
||||
{
|
||||
'content': ep.content,
|
||||
'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
|
||||
}
|
||||
for ep in previous_episodes
|
||||
],
|
||||
}
|
||||
|
||||
llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v3(context))
|
||||
new_nodes_data = llm_response.get('new_nodes', [])
|
||||
extracted_node_data: list[dict[str, Any]] = []
|
||||
if episode.source == EpisodeType.message:
|
||||
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
|
||||
elif episode.source == EpisodeType.json:
|
||||
extracted_node_data = await extract_json_nodes(llm_client, episode)
|
||||
|
||||
end = time()
|
||||
logger.info(f'Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms')
|
||||
logger.info(f'Extracted new nodes: {extracted_node_data} in {(end - start) * 1000} ms')
|
||||
# Convert the extracted data into EntityNode objects
|
||||
new_nodes = []
|
||||
for node_data in new_nodes_data:
|
||||
for node_data in extracted_node_data:
|
||||
new_node = EntityNode(
|
||||
name=node_data['name'],
|
||||
labels=node_data['labels'],
|
||||
|
||||
@ -91,6 +91,15 @@ async def invalidate_edges(
|
||||
return invalidated_edges
|
||||
|
||||
|
||||
def extract_date_strings_from_edge(edge: EntityEdge) -> str:
    """Render an edge's validity window as text for the invalidation prompt.

    Args:
        edge: Edge whose ``valid_at`` and ``invalid_at`` are read.

    Returns:
        ``'Start Date: <iso>'`` when ``valid_at`` is set, followed by
        ``' (End Date: <iso>)'`` when ``invalid_at`` is set. The end-date part
        keeps its leading space even when there is no start date, and the
        result is ``''`` when neither date is set.
    """
    valid_from, valid_until = edge.valid_at, edge.invalid_at

    pieces = []
    if valid_from:
        pieces.append(f'Start Date: {valid_from.isoformat()}')
    if valid_until:
        pieces.append(f' (End Date: {valid_until.isoformat()})')
    return ''.join(pieces)
|
||||
|
||||
|
||||
def prepare_invalidation_context(
|
||||
existing_edges: list[NodeEdgeNodeTriplet],
|
||||
new_edges: list[NodeEdgeNodeTriplet],
|
||||
@ -99,15 +108,15 @@ def prepare_invalidation_context(
|
||||
) -> dict:
|
||||
return {
|
||||
'existing_edges': [
|
||||
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
|
||||
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
|
||||
for source_node, edge, target_node in sorted(
|
||||
existing_edges, key=lambda x: x[1].created_at, reverse=True
|
||||
existing_edges, key=lambda x: (x[1].created_at), reverse=True
|
||||
)
|
||||
],
|
||||
'new_edges': [
|
||||
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
|
||||
f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) {extract_date_strings_from_edge(edge)}'
|
||||
for source_node, edge, target_node in sorted(
|
||||
new_edges, key=lambda x: x[1].created_at, reverse=True
|
||||
new_edges, key=lambda x: (x[1].created_at), reverse=True
|
||||
)
|
||||
],
|
||||
'current_episode': current_episode.content,
|
||||
|
||||
1613
poetry.lock
generated
1613
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -22,6 +22,7 @@ sentence-transformers = "^3.0.1"
|
||||
diskcache = "^5.6.3"
|
||||
arrow = "^1.3.0"
|
||||
openai = "^1.38.0"
|
||||
anthropic = "^0.34.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^8.3.2"
|
||||
@ -33,6 +34,9 @@ ruff = "^0.6.2"
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pydantic = "^2.8.2"
|
||||
mypy = "^1.11.1"
|
||||
groq = "^0.9.0"
|
||||
ipykernel = "^6.29.5"
|
||||
jupyterlab = "^4.2.4"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
131
runner.py
131
runner.py
@ -1,131 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||
|
||||
load_dotenv()
|
||||
|
||||
neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
|
||||
neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
|
||||
neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
|
||||
|
||||
|
||||
def setup_logging():
    """Send INFO-level (and above) log records from the root logger to stdout.

    Each call attaches one more stream handler to the root logger.

    Returns:
        The root logger, after its level is set and the handler is attached.
    """
    # Handler: writes to stdout with a timestamp/name/level/message layout.
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    # Root logger: accept INFO and above, then route through the handler.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(stdout_handler)
    return root_logger
|
||||
|
||||
|
||||
bmw_sales = [
|
||||
{
|
||||
'episode_body': 'Paul (buyer): Hi, I would like to buy a new car',
|
||||
},
|
||||
{
|
||||
'episode_body': 'Dan The Salesman (salesman): Sure, I can help you with that. What kind of car are you looking for?',
|
||||
},
|
||||
{
|
||||
'episode_body': 'Paul (buyer): I am looking for a new BMW',
|
||||
},
|
||||
{
|
||||
'episode_body': 'Dan The Salesman (salesman): Great choice! What kind of BMW are you looking for?',
|
||||
},
|
||||
{
|
||||
'episode_body': 'Paul (buyer): I am considering a BMW 3 series',
|
||||
},
|
||||
{
|
||||
'episode_body': 'Dan The Salesman (salesman): Great choice, we currently have a 2024 BMW 3 series in stock, it is a great car and costs $50,000',
|
||||
},
|
||||
{
|
||||
'episode_body': "Paul (buyer): Actually I am interested in something cheaper, I won't consider anything over $30,000",
|
||||
},
|
||||
]
|
||||
|
||||
dates_mentioned = [
|
||||
{
|
||||
'episode_body': 'Paul (user): I have graduated from Univerity of Toronto in 2022',
|
||||
},
|
||||
{
|
||||
'episode_body': 'Jane (user): How cool, I graduated from the same school in 1999',
|
||||
},
|
||||
]
|
||||
|
||||
times_mentioned = [
|
||||
{
|
||||
'episode_body': 'Paul (user): 15 minutes ago we put a deposit on our new house',
|
||||
},
|
||||
]
|
||||
|
||||
time_range_mentioned = [
|
||||
{
|
||||
'episode_body': 'Paul (user): I served as a US Marine in 2015-2019',
|
||||
},
|
||||
]
|
||||
|
||||
relative_time_range_mentioned = [
|
||||
{
|
||||
'episode_body': 'Paul (user): I lived in Toronto for 10 years, until moving to Vancouver yesterday',
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def main():
    """Rebuild the graph from scratch and ingest the ``bmw_sales`` conversation.

    Sets up logging, connects to Neo4j with the module-level credentials,
    clears existing data, builds indices and constraints, then adds every
    ``bmw_sales`` entry as an episode named ``Message <index>``.
    """
    setup_logging()
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
    await clear_data(graphiti.driver)
    await graphiti.build_indices_and_constraints()

    for index, entry in enumerate(bmw_sales):
        await graphiti.add_episode(
            name=f'Message {index}',
            episode_body=entry['episode_body'],
            source_description='',
            # reference_time=datetime.now() - timedelta(days=365 * 3),
            reference_time=datetime.now(),
        )
    # await client.add_episode(
    #     name='Message 5',
    #     episode_body='Jane: I miss Paul',
    #     source_description='WhatsApp Message',
    #     reference_time=datetime.now(),
    # )
    # await client.add_episode(
    #     name='Message 6',
    #     episode_body='Jane: I dont miss Paul anymore, I hate him',
    #     source_description='WhatsApp Message',
    #     reference_time=datetime.now(),
    # )

    # await client.add_episode(
    #     name="Message 3",
    #     episode_body="Assistant: The best type of apples available are Fuji apples",
    #     source_description="WhatsApp Message",
    # )
    # await client.add_episode(
    #     name="Message 4",
    #     episode_body="Paul: Oh, I actually hate those",
    #     source_description="WhatsApp Message",
    # )
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
@ -1,10 +1,12 @@
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||
extract_date_strings_from_edge,
|
||||
prepare_edges_for_invalidation,
|
||||
prepare_invalidation_context,
|
||||
)
|
||||
@ -153,7 +155,7 @@ def test_prepare_invalidation_context():
|
||||
content='This is the current episode content.',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='test',
|
||||
source=EpisodeType.message,
|
||||
source_description='Test episode for unit testing',
|
||||
)
|
||||
previous_episodes = [
|
||||
@ -162,7 +164,7 @@ def test_prepare_invalidation_context():
|
||||
content='This is the content of previous episode 1.',
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source='test',
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode 1 for unit testing',
|
||||
),
|
||||
EpisodicNode(
|
||||
@ -170,7 +172,7 @@ def test_prepare_invalidation_context():
|
||||
content='This is the content of previous episode 2.',
|
||||
created_at=now - timedelta(days=2),
|
||||
valid_at=now - timedelta(days=2),
|
||||
source='test',
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode 2 for unit testing',
|
||||
),
|
||||
]
|
||||
@ -197,7 +199,7 @@ def test_prepare_invalidation_context():
|
||||
assert node1.name in existing_edge_str
|
||||
assert edge1.name in existing_edge_str
|
||||
assert node2.name in existing_edge_str
|
||||
assert edge1.created_at.isoformat() in existing_edge_str
|
||||
assert edge1.fact in existing_edge_str
|
||||
|
||||
# Check the format of the new edge
|
||||
new_edge_str = result['new_edges'][0]
|
||||
@ -205,7 +207,7 @@ def test_prepare_invalidation_context():
|
||||
assert node2.name in new_edge_str
|
||||
assert edge2.name in new_edge_str
|
||||
assert node3.name in new_edge_str
|
||||
assert edge2.created_at.isoformat() in new_edge_str
|
||||
assert edge2.fact in new_edge_str
|
||||
|
||||
|
||||
def test_prepare_invalidation_context_empty_input():
|
||||
@ -215,7 +217,7 @@ def test_prepare_invalidation_context_empty_input():
|
||||
content='Empty episode',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='test',
|
||||
source=EpisodeType.message,
|
||||
source_description='Test empty episode for unit testing',
|
||||
)
|
||||
result = prepare_invalidation_context([], [], current_episode, [])
|
||||
@ -267,7 +269,7 @@ def test_prepare_invalidation_context_sorting():
|
||||
content='This is the current episode content.',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='test',
|
||||
source=EpisodeType.message,
|
||||
source_description='Test episode for unit testing',
|
||||
)
|
||||
previous_episodes = [
|
||||
@ -276,7 +278,7 @@ def test_prepare_invalidation_context_sorting():
|
||||
content='This is the content of a previous episode.',
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source='test',
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode for unit testing',
|
||||
),
|
||||
]
|
||||
@ -293,6 +295,43 @@ def test_prepare_invalidation_context_sorting():
|
||||
assert result['previous_episodes'][0] == previous_episodes[0].content
|
||||
|
||||
|
||||
class TestExtractDateStringsFromEdge(unittest.TestCase):
    """Checks extract_date_strings_from_edge for every valid_at/invalid_at combination."""

    def generate_entity_edge(self, valid_at, invalid_at):
        """Build a KNOWS edge between nodes '1' and '2' with the given validity bounds."""
        edge_fields = {
            'source_node_uuid': '1',
            'target_node_uuid': '2',
            'name': 'KNOWS',
            'fact': 'Node1 knows Node2',
            'created_at': datetime.now(),
            'valid_at': valid_at,
            'invalid_at': invalid_at,
        }
        return EntityEdge(**edge_fields)

    def test_both_dates_present(self):
        # Start and end are both rendered, the end in parentheses.
        edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), datetime(2024, 1, 2, 12, 0))
        self.assertEqual(
            extract_date_strings_from_edge(edge),
            'Start Date: 2024-01-01T12:00:00 (End Date: 2024-01-02T12:00:00)',
        )

    def test_only_valid_at_present(self):
        # Without an end date only the start part is produced.
        edge = self.generate_entity_edge(datetime(2024, 1, 1, 12, 0), None)
        self.assertEqual(
            extract_date_strings_from_edge(edge),
            'Start Date: 2024-01-01T12:00:00',
        )

    def test_only_invalid_at_present(self):
        # The end part keeps its leading space even with no start date.
        edge = self.generate_entity_edge(None, datetime(2024, 1, 2, 12, 0))
        self.assertEqual(
            extract_date_strings_from_edge(edge),
            ' (End Date: 2024-01-02T12:00:00)',
        )

    def test_no_dates_present(self):
        # No dates at all yields an empty string.
        edge = self.generate_entity_edge(None, None)
        self.assertEqual(extract_date_strings_from_edge(edge), '')
|
||||
|
||||
|
||||
# Run the tests
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
||||
@ -6,7 +6,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.llm_client import LLMConfig, OpenAIClient
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||
from graphiti_core.utils.maintenance.temporal_operations import (
|
||||
invalidate_edges,
|
||||
)
|
||||
@ -58,7 +58,7 @@ def create_test_data():
|
||||
content='Alice now dislikes Bob',
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
source='test',
|
||||
source=EpisodeType.message,
|
||||
source_description='Test episode for unit testing',
|
||||
)
|
||||
|
||||
@ -69,7 +69,7 @@ def create_test_data():
|
||||
content='Alice liked Bob',
|
||||
created_at=now - timedelta(days=1),
|
||||
valid_at=now - timedelta(days=1),
|
||||
source='test',
|
||||
source=EpisodeType.message,
|
||||
source_description='Test previous episode for unit testing',
|
||||
)
|
||||
]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user