e2e graph builder eval (#343)

* add partial eval platform

* dedupe updates

* add e2e eval

* Update graphiti_core/prompts/eval.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* clear all outputs

* clear all outputs

* squash eval commits

* Update tests/evals/data/utils.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* add longmemeval disclaimer

* remove gitignore

* add copyright headers

* add cli

* Update tests/evals/data/longmemeval_data/README.md

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Update tests/evals/eval_e2e_graph_building.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* updates

---------

Co-authored-by: jackaldenryan <jackaldenryan@gmail.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Commit 5dce26722e (parent 6aa25a1901) by Preston Rasmussen, committed via GitHub on 2025-04-12 10:35:22 -04:00
8 changed files with 67332 additions and 0 deletions

.gitignore (vendored, 3 changed lines)

@@ -164,3 +164,6 @@ cython_debug/
## Other
# Cache files
cache.db*
# All DS_Store files
.DS_Store

graphiti_core/prompts/eval.py

@@ -37,16 +37,28 @@ class EvalResponse(BaseModel):
    )


class EvalAddEpisodeResults(BaseModel):
    candidate_is_worse: bool = Field(
        ...,
        description='True if the baseline extraction is higher quality than the candidate extraction.',
    )
    reasoning: str = Field(
        ..., description='explanation of why the candidate extraction was judged better or worse than the baseline'
    )
class Prompt(Protocol):
    qa_prompt: PromptVersion
    eval_prompt: PromptVersion
    query_expansion: PromptVersion
    eval_add_episode_results: PromptVersion


class Versions(TypedDict):
    qa_prompt: PromptFunction
    eval_prompt: PromptFunction
    query_expansion: PromptFunction
    eval_add_episode_results: PromptFunction

def query_expansion(context: dict[str, Any]) -> list[Message]:
@@ -112,8 +124,41 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
    ]


def eval_add_episode_results(context: dict[str, Any]) -> list[Message]:
    sys_prompt = """You are a judge that determines whether a baseline graph building result from a list of messages is better
    than a candidate graph building result based on the same messages."""

    user_prompt = f"""
    Given the following PREVIOUS MESSAGES and MESSAGE, determine whether the BASELINE graph data extracted from the
    conversation is higher quality than the CANDIDATE graph data extracted from the conversation.

    Return True if the BASELINE extraction is better than the CANDIDATE extraction (i.e. the candidate is worse),
    and False otherwise. If the CANDIDATE and BASELINE extractions are nearly identical in quality, return False.
    Add your reasoning for your decision to the reasoning field.

    <PREVIOUS MESSAGES>
    {context['previous_messages']}
    </PREVIOUS MESSAGES>
    <MESSAGE>
    {context['message']}
    </MESSAGE>
    <BASELINE>
    {context['baseline']}
    </BASELINE>
    <CANDIDATE>
    {context['candidate']}
    </CANDIDATE>
    """

    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]


versions: Versions = {
    'qa_prompt': qa_prompt,
    'eval_prompt': eval_prompt,
    'query_expansion': query_expansion,
    'eval_add_episode_results': eval_add_episode_results,
}

tests/evals/data/longmemeval_data/README.md (new file, 3 lines)

@@ -0,0 +1,3 @@
The `longmemeval_oracle` dataset is an open-source dataset that we use for evaluation; we did not create it.
It is available
here: https://huggingface.co/datasets/xiaowu0162/longmemeval/blob/main/longmemeval_oracle.
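For orientation, here is a minimal sketch of how the eval code in this PR loads the dataset (the path and the column names are taken from build_graph in eval_e2e_graph_building.py below; the path is relative to tests/evals, and the index 0 is only an example):

import pandas as pd

df = pd.read_json('data/longmemeval_data/longmemeval_oracle.json')
sessions = df['haystack_sessions'].iloc[0]  # each session is a list of {'role': ..., 'content': ...} messages
dates = df['haystack_dates'].iloc[0]  # parallel list of date strings in '%Y/%m/%d (%a) %H:%M' format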

tests/evals/data/longmemeval_data/longmemeval_oracle.json (new file): file diff suppressed because one or more lines are too long

tests/evals/eval_cli.py (new file, 39 lines)

@@ -0,0 +1,39 @@
import argparse
import asyncio

from tests.evals.eval_e2e_graph_building import build_baseline_graph, eval_graph


async def main():
    parser = argparse.ArgumentParser(
        description='Run eval_graph and optionally build_baseline_graph from the command line.'
    )
    parser.add_argument(
        '--multi-session',
        type=int,
        nargs='+',
        required=True,
        help='List of integers representing multi-session values (e.g., 1 2 3)',
    )
    parser.add_argument('--session-length', type=int, required=True, help='Length of each session')
    parser.add_argument(
        '--build-baseline', action='store_true', help='If set, also runs build_baseline_graph'
    )

    args = parser.parse_args()

    # Optionally rebuild the baseline graph before evaluating
    if args.build_baseline:
        print('Running build_baseline_graph...')
        await build_baseline_graph(
            multi_session=args.multi_session, session_length=args.session_length
        )

    # Always call eval_graph
    result = await eval_graph(multi_session=args.multi_session, session_length=args.session_length)
    print('Result of eval_graph:', result)


if __name__ == '__main__':
    asyncio.run(main())
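A typical invocation (the argument values are illustrative only; depending on your PYTHONPATH and working directory you may need to run it as a module from the repository root so the tests.evals import resolves, and note that build_graph reads the dataset via a path relative to tests/evals) might look like:

python -m tests.evals.eval_cli --multi-session 0 1 2 --session-length 10 --build-baseline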

tests/evals/eval_e2e_graph_building.py (new file, 157 lines)

@@ -0,0 +1,157 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from datetime import datetime, timezone
import pandas as pd
from graphiti_core import Graphiti
from graphiti_core.graphiti import AddEpisodeResults
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EpisodeType
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.eval import EvalAddEpisodeResults
from graphiti_core.utils.maintenance import clear_data
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
async def build_graph(
    group_id_suffix: str, multi_session: list[int], session_length: int, graphiti: Graphiti
) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
    # Get longmemeval dataset
    lme_dataset_option = (
        'data/longmemeval_data/longmemeval_oracle.json'  # Can be _oracle, _s, or _m
    )
    lme_dataset_df = pd.read_json(lme_dataset_option)

    add_episode_results: dict[str, list[AddEpisodeResults]] = {}
    add_episode_context: dict[str, list[str]] = {}

    for multi_session_idx in multi_session:
        sessions = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx]
        multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx]

        user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx)
        await clear_data(graphiti.driver, [user_id])

        add_episode_results[user_id] = []
        add_episode_context[user_id] = []

        message_count = 0
        for session_idx, session in enumerate(sessions):
            for msg in session:
                # Cap the number of messages ingested per user at session_length
                if message_count >= session_length:
                    continue
                message_count += 1

                date = multi_session_dates[session_idx] + ' UTC'
                date_format = '%Y/%m/%d (%a) %H:%M UTC'
                date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)

                episode_body = f'{msg["role"]}: {msg["content"]}'
                results = await graphiti.add_episode(
                    name='',
                    episode_body=episode_body,
                    reference_time=date_string,
                    source=EpisodeType.message,
                    source_description='',
                    group_id=user_id + '_' + group_id_suffix,
                )

                # Strip embeddings so the results serialize cleanly to JSON
                for node in results.nodes:
                    node.name_embedding = None
                for edge in results.edges:
                    edge.fact_embedding = None

                add_episode_results[user_id].append(results)
                add_episode_context[user_id].append(msg['content'])

    return add_episode_results, add_episode_context
async def build_baseline_graph(multi_session: list[int], session_length: int):
    # Use gpt-4o for graph building baseline
    llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o'))
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)

    add_episode_results, _ = await build_graph('baseline', multi_session, session_length, graphiti)

    filename = 'baseline_graph_results.json'
    serializable_baseline_graph_results = {
        key: [item.model_dump(mode='json') for item in value]
        for key, value in add_episode_results.items()
    }

    with open(filename, 'w') as file:
        json.dump(serializable_baseline_graph_results, file, indent=4, default=str)
async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float:
    if llm_client is None:
        llm_client = OpenAIClient()
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)

    with open('baseline_graph_results.json') as file:
        baseline_results_raw = json.load(file)

    baseline_results: dict[str, list[AddEpisodeResults]] = {
        key: [AddEpisodeResults(**item) for item in value]
        for key, value in baseline_results_raw.items()
    }

    add_episode_results, add_episode_context = await build_graph(
        'candidate', multi_session, session_length, graphiti
    )

    filename = 'candidate_graph_results.json'
    candidate_baseline_graph_results = {
        key: [item.model_dump(mode='json') for item in value]
        for key, value in add_episode_results.items()
    }

    with open(filename, 'w') as file:
        json.dump(candidate_baseline_graph_results, file, indent=4, default=str)

    raw_score = 0
    user_count = 0
    for user_id in add_episode_results:
        user_count += 1
        user_raw_score = 0
        print('add_episode_context: ', add_episode_context)
        for baseline_result, add_episode_result, episodes in zip(
            baseline_results[user_id],
            add_episode_results[user_id],
            add_episode_context[user_id],
            strict=True,
        ):
            context = {
                'baseline': baseline_result,
                'candidate': add_episode_result,
                'message': episodes[0],
                'previous_messages': episodes[1:],
            }
            print(context)

            llm_response = await llm_client.generate_response(
                prompt_library.eval.eval_add_episode_results(context),
                response_model=EvalAddEpisodeResults,
            )
            candidate_is_worse = llm_response.get('candidate_is_worse', False)
            user_raw_score += 0 if candidate_is_worse else 1
            print('llm_response:', llm_response)

        # Average the per-episode judgments for this user, then average across users
        user_score = user_raw_score / len(add_episode_results[user_id])
        raw_score += user_score

    score = raw_score / user_count
    return score
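To make the scoring concrete with a hypothetical example (not output from a real run): if user A's candidate extraction is judged at least as good as the baseline on 3 of 4 episodes (0.75) and user B's on 1 of 2 episodes (0.5), eval_graph returns (0.75 + 0.5) / 2 = 0.625.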

tests/evals/pytest.ini (new file, 4 lines)

@@ -0,0 +1,4 @@
[pytest]
asyncio_default_fixture_loop_scope = function
markers =
    integration: marks tests as integration tests
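With the integration marker registered here, a test decorated with @pytest.mark.integration can be selected or deselected via pytest's standard -m flag (e.g. pytest -m integration); only the marker name is project-specific, the selection mechanism is stock pytest.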

tests/evals/utils.py (new file, 39 lines)

@@ -0,0 +1,39 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import sys


def setup_logging():
    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Set the logging level to INFO

    # Create console handler and set level to INFO
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # Add formatter to console handler
    console_handler.setFormatter(formatter)

    # Add console handler to logger
    logger.addHandler(console_handler)

    return logger
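A short sketch of how this helper would be used from an eval script (the log message is illustrative):

from tests.evals.utils import setup_logging

logger = setup_logging()
logger.info('starting e2e graph building eval')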