Mirror of https://github.com/getzep/graphiti.git, synced 2025-06-27 02:00:02 +00:00
e2e graph builder eval (#343)
* add partial eval platform
* dedupe updates
* add e2e eval
* Update graphiti_core/prompts/eval.py
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* clear all outputs
* clear all outputs
* squash eval commits
* Update tests/evals/data/utils.py
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* add longmemeval disclaimer
* remove gitignore
* add copyright headers
* add cli
* Update tests/evals/data/longmemeval_data/README.md
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* Update tests/evals/eval_e2e_graph_building.py
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* updates

---------

Co-authored-by: jackaldenryan <jackaldenryan@gmail.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent 6aa25a1901
commit 5dce26722e
.gitignore (vendored): +3

@@ -164,3 +164,6 @@ cython_debug/
 ## Other
 # Cache files
 cache.db*
+
+# All DS_Store files
+.DS_Store
graphiti_core/prompts/eval.py

@@ -37,16 +37,28 @@ class EvalResponse(BaseModel):
     )


+class EvalAddEpisodeResults(BaseModel):
+    candidate_is_worse: bool = Field(
+        ...,
+        description='boolean if the baseline extraction is higher quality than the candidate extraction.',
+    )
+    reasoning: str = Field(
+        ..., description='why you determined the response was correct or incorrect'
+    )
+
+
 class Prompt(Protocol):
     qa_prompt: PromptVersion
     eval_prompt: PromptVersion
     query_expansion: PromptVersion
+    eval_add_episode_results: PromptVersion


 class Versions(TypedDict):
     qa_prompt: PromptFunction
     eval_prompt: PromptFunction
     query_expansion: PromptFunction
+    eval_add_episode_results: PromptFunction


 def query_expansion(context: dict[str, Any]) -> list[Message]:
@@ -112,8 +124,41 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
     ]


+def eval_add_episode_results(context: dict[str, Any]) -> list[Message]:
+    sys_prompt = """You are a judge that determines whether a baseline graph building result from a list of messages is better
+    than a candidate graph building result based on the same messages."""
+
+    user_prompt = f"""
+    Given the following PREVIOUS MESSAGES and MESSAGE, determine if the BASELINE graph data extracted from the
+    conversation is higher quality than the CANDIDATE graph data extracted from the conversation.
+
+    Return False if the BASELINE extraction is better, and True otherwise. If the CANDIDATE extraction and
+    BASELINE extraction are nearly identical in quality, return True. Add your reasoning for your decision to the reasoning field
+
+    <PREVIOUS MESSAGES>
+    {context['previous_messages']}
+    </PREVIOUS MESSAGES>
+    <MESSAGE>
+    {context['message']}
+    </MESSAGE>
+
+    <BASELINE>
+    {context['baseline']}
+    </BASELINE>
+
+    <CANDIDATE>
+    {context['candidate']}
+    </CANDIDATE>
+    """
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
+
+
 versions: Versions = {
     'qa_prompt': qa_prompt,
     'eval_prompt': eval_prompt,
     'query_expansion': query_expansion,
+    'eval_add_episode_results': eval_add_episode_results,
 }
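Usage note: a minimal sketch of exercising the new judge prompt together with the structured response model. The sample context values are invented for illustration; the generate_response/response_model pattern mirrors eval_graph in tests/evals/eval_e2e_graph_building.py below.

from graphiti_core.llm_client import OpenAIClient
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.eval import EvalAddEpisodeResults


async def judge_once(llm_client: OpenAIClient) -> bool:
    # Sample context; real callers pass graph extractions and conversation text.
    context = {
        'previous_messages': ['user: I moved to Lisbon last month.'],
        'message': 'user: I work remotely as a designer.',
        'baseline': {'nodes': ['user', 'Lisbon'], 'edges': ['user LIVES_IN Lisbon']},
        'candidate': {'nodes': ['user', 'Lisbon'], 'edges': ['user LIVES_IN Lisbon']},
    }
    response = await llm_client.generate_response(
        prompt_library.eval.eval_add_episode_results(context),
        response_model=EvalAddEpisodeResults,
    )
    # The response is dict-like; a missing key is treated as "candidate not
    # worse", matching eval_graph's default below.
    return response.get('candidate_is_worse', False)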
tests/evals/data/longmemeval_data/README.md (new file): +3

@@ -0,0 +1,3 @@
+The `longmemeval_oracle` dataset is an open-source dataset that we are using.
+We did not create this dataset and it can be found
+here: https://huggingface.co/datasets/xiaowu0162/longmemeval/blob/main/longmemeval_oracle.
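For orientation, the fields the eval code below reads from each dataset record look roughly like this. This is a sketch inferred from build_graph's usage of haystack_sessions and haystack_dates, not the dataset's documented schema; the message text and date are invented.

# Sketch of the record shape build_graph relies on (inferred from usage):
record = {
    'haystack_sessions': [   # one entry per session in the multi-session haystack
        [                    # each session: a list of {'role', 'content'} messages
            {'role': 'user', 'content': 'I adopted a cat named Miso.'},
            {'role': 'assistant', 'content': 'Congratulations on adopting Miso!'},
        ],
    ],
    'haystack_dates': [      # one date string per session; build_graph appends
        '2023/05/20 (Sat) 02:21',  # ' UTC' and parses '%Y/%m/%d (%a) %H:%M UTC'
    ],
}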
tests/evals/data/longmemeval_data/longmemeval_oracle.json (new file): +67042

File diff suppressed because one or more lines are too long.
tests/evals/eval_cli.py (new file): +39

@@ -0,0 +1,39 @@
+import argparse
+import asyncio
+
+from tests.evals.eval_e2e_graph_building import build_baseline_graph, eval_graph
+
+
+async def main():
+    parser = argparse.ArgumentParser(
+        description='Run eval_graph and optionally build_baseline_graph from the command line.'
+    )
+
+    parser.add_argument(
+        '--multi-session',
+        type=int,
+        nargs='+',
+        required=True,
+        help='List of integers representing multi-session values (e.g., 1 2 3)',
+    )
+    parser.add_argument('--session-length', type=int, required=True, help='Length of each session')
+    parser.add_argument(
+        '--build-baseline', action='store_true', help='If set, also runs build_baseline_graph'
+    )
+
+    args = parser.parse_args()
+
+    # Optionally run the async function
+    if args.build_baseline:
+        print('Running build_baseline_graph...')
+        await build_baseline_graph(
+            multi_session=args.multi_session, session_length=args.session_length
+        )
+
+    # Always call eval_graph
+    result = await eval_graph(multi_session=args.multi_session, session_length=args.session_length)
+    print('Result of eval_graph:', result)
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
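A quick usage sketch for the new CLI. The invocation path is an assumption about the repo layout, and running it requires the Neo4j and OpenAI configuration the eval module below relies on.

# Shell invocation (assumed layout, run from the repo root):
#   python tests/evals/eval_cli.py --multi-session 0 1 2 --session-length 10 --build-baseline
#
# Programmatic equivalent, calling the same functions the CLI wraps:
import asyncio

from tests.evals.eval_e2e_graph_building import build_baseline_graph, eval_graph


async def run() -> None:
    await build_baseline_graph(multi_session=[0, 1, 2], session_length=10)
    score = await eval_graph(multi_session=[0, 1, 2], session_length=10)
    print('eval_graph score:', score)


asyncio.run(run())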
tests/evals/eval_e2e_graph_building.py (new file): +157

@@ -0,0 +1,157 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+from datetime import datetime, timezone
+
+import pandas as pd
+
+from graphiti_core import Graphiti
+from graphiti_core.graphiti import AddEpisodeResults
+from graphiti_core.llm_client import LLMConfig, OpenAIClient
+from graphiti_core.nodes import EpisodeType
+from graphiti_core.prompts import prompt_library
+from graphiti_core.prompts.eval import EvalAddEpisodeResults
+from graphiti_core.utils.maintenance import clear_data
+from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
+
+
+async def build_graph(
+    group_id_suffix: str, multi_session: list[int], session_length: int, graphiti: Graphiti
+) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
+    # Get longmemeval dataset
+    lme_dataset_option = (
+        'data/longmemeval_data/longmemeval_oracle.json'  # Can be _oracle, _s, or _m
+    )
+    lme_dataset_df = pd.read_json(lme_dataset_option)
+
+    add_episode_results: dict[str, list[AddEpisodeResults]] = {}
+    add_episode_context: dict[str, list[str]] = {}
+    for multi_session_idx in multi_session:
+        multi_session = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx]
+        multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx]
+
+        user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx)
+        await clear_data(graphiti.driver, [user_id])
+
+        add_episode_results[user_id] = []
+        add_episode_context[user_id] = []
+
+        message_count = 0
+        for session_idx, session in enumerate(multi_session):
+            for _, msg in enumerate(session):
+                if message_count >= session_length:
+                    continue
+                message_count += 1
+                date = multi_session_dates[session_idx] + ' UTC'
+                date_format = '%Y/%m/%d (%a) %H:%M UTC'
+                date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)
+
+                episode_body = f'{msg["role"]}: {msg["content"]}'
+                results = await graphiti.add_episode(
+                    name='',
+                    episode_body=episode_body,
+                    reference_time=date_string,
+                    source=EpisodeType.message,
+                    source_description='',
+                    group_id=user_id + '_' + group_id_suffix,
+                )
+                for node in results.nodes:
+                    node.name_embedding = None
+                for edge in results.edges:
+                    edge.fact_embedding = None
+
+                add_episode_results[user_id].append(results)
+                add_episode_context[user_id].append(msg['content'])
+    return add_episode_results, add_episode_context
+
+
+async def build_baseline_graph(multi_session: list[int], session_length: int):
+    # Use gpt-4o for graph building baseline
+    llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o'))
+    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
+
+    add_episode_results, _ = await build_graph('baseline', multi_session, session_length, graphiti)
+
+    filename = 'baseline_graph_results.json'
+
+    serializable_baseline_graph_results = {
+        key: [item.model_dump(mode='json') for item in value]
+        for key, value in add_episode_results.items()
+    }
+
+    with open(filename, 'w') as file:
+        json.dump(serializable_baseline_graph_results, file, indent=4, default=str)
+
+
+async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float:
+    if llm_client is None:
+        llm_client = OpenAIClient()
+    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
+    with open('baseline_graph_results.json') as file:
+        baseline_results_raw = json.load(file)
+
+    baseline_results: dict[str, list[AddEpisodeResults]] = {
+        key: [AddEpisodeResults(**item) for item in value]
+        for key, value in baseline_results_raw.items()
+    }
+    add_episode_results, add_episode_context = await build_graph(
+        'candidate', multi_session, session_length, graphiti
+    )
+
+    filename = 'candidate_graph_results.json'
+
+    candidate_baseline_graph_results = {
+        key: [item.model_dump(mode='json') for item in value]
+        for key, value in add_episode_results.items()
+    }
+
+    with open(filename, 'w') as file:
+        json.dump(candidate_baseline_graph_results, file, indent=4, default=str)
+
+    raw_score = 0
+    user_count = 0
+    for user_id in add_episode_results:
+        user_count += 1
+        user_raw_score = 0
+        print('add_episode_context: ', add_episode_context)
+        for baseline_result, add_episode_result, episodes in zip(
+            baseline_results[user_id],
+            add_episode_results[user_id],
+            add_episode_context[user_id],
+            strict=True,
+        ):
+            context = {
+                'baseline': baseline_result,
+                'candidate': add_episode_result,
+                'message': episodes[0],
+                'previous_messages': episodes[1:],
+            }
+            print(context)
+
+            llm_response = await llm_client.generate_response(
+                prompt_library.eval.eval_add_episode_results(context),
+                response_model=EvalAddEpisodeResults,
+            )
+
+            candidate_is_worse = llm_response.get('candidate_is_worse', False)
+            user_raw_score += 0 if candidate_is_worse else 1
+            print('llm_response:', llm_response)
+        user_score = user_raw_score / len(add_episode_results[user_id])
+        raw_score += user_score
+    score = raw_score / user_count
+
+    return score
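The score eval_graph reports is the mean over users of the fraction of episodes where the judge did not mark the candidate as worse. A minimal sketch of that arithmetic, with invented verdict values:

# verdicts maps each user to the judge's per-episode candidate_is_worse
# booleans (values here are invented for illustration).
verdicts: dict[str, list[bool]] = {
    'lme_oracle_experiment_user_0': [False, False, True],  # 1 of 3 marked worse
    'lme_oracle_experiment_user_1': [False, False],        # none marked worse
}

# Per-user score: fraction of episodes where the candidate was not worse.
user_scores = [
    sum(0 if worse else 1 for worse in per_episode) / len(per_episode)
    for per_episode in verdicts.values()
]
# Overall score: mean of the per-user scores.
score = sum(user_scores) / len(user_scores)
print(f'score = {score:.3f}')  # (2/3 + 2/2) / 2 ≈ 0.833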
tests/evals/pytest.ini (new file): +4

@@ -0,0 +1,4 @@
+[pytest]
+asyncio_default_fixture_loop_scope = function
+markers =
+    integration: marks tests as integration tests
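For reference, the integration marker registered above is applied and selected in the usual pytest way; the test name and body here are placeholders, selected with `pytest tests/evals -m integration`.

import pytest


@pytest.mark.integration
def test_graph_building_end_to_end():
    # Placeholder body; real integration tests would hit Neo4j and OpenAI.
    ...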
tests/evals/utils.py (new file): +39

@@ -0,0 +1,39 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import sys
+
+
+def setup_logging():
+    # Create a logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)  # Set the logging level to INFO
+
+    # Create console handler and set level to INFO
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(logging.INFO)
+
+    # Create formatter
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+    # Add formatter to console handler
+    console_handler.setFormatter(formatter)
+
+    # Add console handler to logger
+    logger.addHandler(console_handler)
+
+    return logger
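A one-line usage sketch for the helper, assuming the module path above:

from tests.evals.utils import setup_logging

logger = setup_logging()
logger.info('starting e2e graph builder eval')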