Mirror of https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
e2e graph builder eval (#343)
* add partial eval platform
* dedupe updates
* add e2e eval
* Update graphiti_core/prompts/eval.py
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* clear all outputs
* clear all outputs
* squash eval commits
* Update tests/evals/data/utils.py
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* add longmemeval disclaimer
* remove gitignore
* add copyright headers
* add cli
* Update tests/evals/data/longmemeval_data/README.md
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* Update tests/evals/eval_e2e_graph_building.py
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* updates

---------

Co-authored-by: jackaldenryan <jackaldenryan@gmail.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent 6aa25a1901
commit 5dce26722e
.gitignore (vendored, 3 lines changed)
@@ -164,3 +164,6 @@ cython_debug/
## Other
# Cache files
cache.db*

# All DS_Store files
.DS_Store
@@ -37,16 +37,28 @@ class EvalResponse(BaseModel):
    )


class EvalAddEpisodeResults(BaseModel):
    candidate_is_worse: bool = Field(
        ...,
        description='True if the baseline extraction is higher quality than the candidate extraction.',
    )
    reasoning: str = Field(
        ..., description='why you determined the response was correct or incorrect'
    )


class Prompt(Protocol):
    qa_prompt: PromptVersion
    eval_prompt: PromptVersion
    query_expansion: PromptVersion
    eval_add_episode_results: PromptVersion


class Versions(TypedDict):
    qa_prompt: PromptFunction
    eval_prompt: PromptFunction
    query_expansion: PromptFunction
    eval_add_episode_results: PromptFunction


def query_expansion(context: dict[str, Any]) -> list[Message]:
@@ -112,8 +124,41 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
    ]


def eval_add_episode_results(context: dict[str, Any]) -> list[Message]:
    sys_prompt = """You are a judge that determines whether a baseline graph building result from a list of messages is better
    than a candidate graph building result based on the same messages."""

    user_prompt = f"""
    Given the following PREVIOUS MESSAGES and MESSAGE, determine if the BASELINE graph data extracted from the
    conversation is higher quality than the CANDIDATE graph data extracted from the conversation.

    Return False if the BASELINE extraction is better, and True otherwise. If the CANDIDATE extraction and
    BASELINE extraction are nearly identical in quality, return True. Add your reasoning for your decision to the reasoning field.

    <PREVIOUS MESSAGES>
    {context['previous_messages']}
    </PREVIOUS MESSAGES>
    <MESSAGE>
    {context['message']}
    </MESSAGE>

    <BASELINE>
    {context['baseline']}
    </BASELINE>

    <CANDIDATE>
    {context['candidate']}
    </CANDIDATE>
    """
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]


versions: Versions = {
    'qa_prompt': qa_prompt,
    'eval_prompt': eval_prompt,
    'query_expansion': query_expansion,
    'eval_add_episode_results': eval_add_episode_results,
}
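A minimal sketch (not part of the diff above) of how the new prompt and response model could be exercised directly. It mirrors the generate_response call made in tests/evals/eval_e2e_graph_building.py further down; the context values are placeholders, the model choice is illustrative, and an OpenAI API key is assumed to be configured in the environment.

import asyncio

from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.eval import EvalAddEpisodeResults


async def judge_extractions() -> None:
    # Illustrative model choice; any configured graphiti LLM client works the same way.
    llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o'))

    # Placeholder context; in the eval script these come from real add_episode runs.
    context = {
        'previous_messages': [],
        'message': 'user: I adopted a dog named Rex last week.',
        'baseline': '<baseline nodes and edges>',
        'candidate': '<candidate nodes and edges>',
    }

    response = await llm_client.generate_response(
        prompt_library.eval.eval_add_episode_results(context),
        response_model=EvalAddEpisodeResults,
    )
    # Per the usage in eval_graph below, the response is a dict with
    # 'candidate_is_worse' and 'reasoning' keys.
    print(response)


if __name__ == '__main__':
    asyncio.run(judge_extractions())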
tests/evals/data/longmemeval_data/README.md (new file, 3 lines)
@@ -0,0 +1,3 @@
The `longmemeval_oracle` dataset is an open-source dataset that we are using.
We did not create this dataset; it can be found here:
https://huggingface.co/datasets/xiaowu0162/longmemeval/blob/main/longmemeval_oracle.
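For orientation, a small sketch (not part of the commit) of how the dataset is read, mirroring the pandas usage in tests/evals/eval_e2e_graph_building.py; the path assumes the JSON sits in tests/evals/data/longmemeval_data/ and that the code runs from tests/evals/.

import pandas as pd

# Path is relative to tests/evals/, matching lme_dataset_option in the eval script.
lme_dataset_df = pd.read_json('data/longmemeval_data/longmemeval_oracle.json')

# Columns referenced by build_graph: per-question lists of chat sessions and their dates.
sessions = lme_dataset_df['haystack_sessions'].iloc[0]
session_dates = lme_dataset_df['haystack_dates'].iloc[0]
print(len(sessions), session_dates[0])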
tests/evals/data/longmemeval_data/longmemeval_oracle.json (new file, 67042 lines)
File diff suppressed because one or more lines are too long
tests/evals/eval_cli.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import argparse
import asyncio

from tests.evals.eval_e2e_graph_building import build_baseline_graph, eval_graph


async def main():
    parser = argparse.ArgumentParser(
        description='Run eval_graph and optionally build_baseline_graph from the command line.'
    )

    parser.add_argument(
        '--multi-session',
        type=int,
        nargs='+',
        required=True,
        help='List of integers representing multi-session values (e.g., 1 2 3)',
    )
    parser.add_argument('--session-length', type=int, required=True, help='Length of each session')
    parser.add_argument(
        '--build-baseline', action='store_true', help='If set, also runs build_baseline_graph'
    )

    args = parser.parse_args()

    # Optionally run the async function
    if args.build_baseline:
        print('Running build_baseline_graph...')
        await build_baseline_graph(
            multi_session=args.multi_session, session_length=args.session_length
        )

    # Always call eval_graph
    result = await eval_graph(multi_session=args.multi_session, session_length=args.session_length)
    print('Result of eval_graph:', result)


if __name__ == '__main__':
    asyncio.run(main())
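A plausible invocation from the repository root would be something like `python -m tests.evals.eval_cli --multi-session 0 1 2 --session-length 10 --build-baseline`; the flag names come from the argparse definitions above, while the module-style invocation is an assumption so that the tests package resolves. The same flow can also be driven programmatically, as in this sketch (the session indices and length are illustrative):

import asyncio

from tests.evals.eval_e2e_graph_building import build_baseline_graph, eval_graph


async def run_eval() -> None:
    # First build the gpt-4o baseline graphs, then score a candidate run against them.
    await build_baseline_graph(multi_session=[0, 1, 2], session_length=10)
    score = await eval_graph(multi_session=[0, 1, 2], session_length=10)
    print('Result of eval_graph:', score)


if __name__ == '__main__':
    asyncio.run(run_eval())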
tests/evals/eval_e2e_graph_building.py (new file, 157 lines)
@@ -0,0 +1,157 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
from datetime import datetime, timezone

import pandas as pd

from graphiti_core import Graphiti
from graphiti_core.graphiti import AddEpisodeResults
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EpisodeType
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.eval import EvalAddEpisodeResults
from graphiti_core.utils.maintenance import clear_data
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER


async def build_graph(
    group_id_suffix: str, multi_session: list[int], session_length: int, graphiti: Graphiti
) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
    # Get longmemeval dataset
    lme_dataset_option = (
        'data/longmemeval_data/longmemeval_oracle.json'  # Can be _oracle, _s, or _m
    )
    lme_dataset_df = pd.read_json(lme_dataset_option)

    add_episode_results: dict[str, list[AddEpisodeResults]] = {}
    add_episode_context: dict[str, list[str]] = {}
    for multi_session_idx in multi_session:
        sessions = lme_dataset_df['haystack_sessions'].iloc[multi_session_idx]
        multi_session_dates = lme_dataset_df['haystack_dates'].iloc[multi_session_idx]

        user_id = 'lme_oracle_experiment_user_' + str(multi_session_idx)
        await clear_data(graphiti.driver, [user_id])

        add_episode_results[user_id] = []
        add_episode_context[user_id] = []

        message_count = 0
        for session_idx, session in enumerate(sessions):
            for _, msg in enumerate(session):
                if message_count >= session_length:
                    continue
                message_count += 1
                date = multi_session_dates[session_idx] + ' UTC'
                date_format = '%Y/%m/%d (%a) %H:%M UTC'
                date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)

                episode_body = f'{msg["role"]}: {msg["content"]}'
                results = await graphiti.add_episode(
                    name='',
                    episode_body=episode_body,
                    reference_time=date_string,
                    source=EpisodeType.message,
                    source_description='',
                    group_id=user_id + '_' + group_id_suffix,
                )
                for node in results.nodes:
                    node.name_embedding = None
                for edge in results.edges:
                    edge.fact_embedding = None

                add_episode_results[user_id].append(results)
                add_episode_context[user_id].append(msg['content'])
    return add_episode_results, add_episode_context


async def build_baseline_graph(multi_session: list[int], session_length: int):
    # Use gpt-4o for graph building baseline
    llm_client = OpenAIClient(config=LLMConfig(model='gpt-4o'))
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)

    add_episode_results, _ = await build_graph('baseline', multi_session, session_length, graphiti)

    filename = 'baseline_graph_results.json'

    serializable_baseline_graph_results = {
        key: [item.model_dump(mode='json') for item in value]
        for key, value in add_episode_results.items()
    }

    with open(filename, 'w') as file:
        json.dump(serializable_baseline_graph_results, file, indent=4, default=str)


async def eval_graph(multi_session: list[int], session_length: int, llm_client=None) -> float:
    if llm_client is None:
        llm_client = OpenAIClient()
    graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
    with open('baseline_graph_results.json') as file:
        baseline_results_raw = json.load(file)

    baseline_results: dict[str, list[AddEpisodeResults]] = {
        key: [AddEpisodeResults(**item) for item in value]
        for key, value in baseline_results_raw.items()
    }
    add_episode_results, add_episode_context = await build_graph(
        'candidate', multi_session, session_length, graphiti
    )

    filename = 'candidate_graph_results.json'

    candidate_baseline_graph_results = {
        key: [item.model_dump(mode='json') for item in value]
        for key, value in add_episode_results.items()
    }

    with open(filename, 'w') as file:
        json.dump(candidate_baseline_graph_results, file, indent=4, default=str)

    raw_score = 0
    user_count = 0
    for user_id in add_episode_results:
        user_count += 1
        user_raw_score = 0
        print('add_episode_context: ', add_episode_context)
        for baseline_result, add_episode_result, episodes in zip(
            baseline_results[user_id],
            add_episode_results[user_id],
            add_episode_context[user_id],
            strict=True,
        ):
            context = {
                'baseline': baseline_result,
                'candidate': add_episode_result,
                'message': episodes[0],
                'previous_messages': episodes[1:],
            }
            print(context)

            llm_response = await llm_client.generate_response(
                prompt_library.eval.eval_add_episode_results(context),
                response_model=EvalAddEpisodeResults,
            )

            candidate_is_worse = llm_response.get('candidate_is_worse', False)
            user_raw_score += 0 if candidate_is_worse else 1
            print('llm_response:', llm_response)
        user_score = user_raw_score / len(add_episode_results[user_id])
        raw_score += user_score
    score = raw_score / user_count

    return score
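For intuition on the score returned by eval_graph: each judged episode contributes 1 when the judge reports the candidate is not worse, episode scores are averaged per user, and the per-user averages are averaged again. A toy calculation with made-up judgments (values are illustrative, not from any real run):

# Hypothetical judge outputs: True means the candidate was judged worse for that episode.
judgments = {
    'user_0': [False, False, True, False],  # 3 of 4 episodes not worse -> 0.75
    'user_1': [False, True],                # 1 of 2 episodes not worse -> 0.50
}

user_scores = [
    sum(0 if worse else 1 for worse in episodes) / len(episodes)
    for episodes in judgments.values()
]
print(sum(user_scores) / len(user_scores))  # (0.75 + 0.50) / 2 = 0.625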
tests/evals/pytest.ini (new file, 4 lines)
@@ -0,0 +1,4 @@
[pytest]
asyncio_default_fixture_loop_scope = function
markers =
    integration: marks tests as integration tests
tests/evals/utils.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import sys


def setup_logging():
    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Set the logging level to INFO

    # Create console handler and set level to INFO
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # Add formatter to console handler
    console_handler.setFormatter(formatter)

    # Add console handler to logger
    logger.addHandler(console_handler)

    return logger
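A minimal usage sketch for the helper above (the log message is illustrative):

from tests.evals.utils import setup_logging

logger = setup_logging()
logger.info('starting graph building eval')  # INFO and above now stream to stdout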