refactor!: use Chat Generator in LLM evaluators (#9116)

* use chatgenerator instead of generator

* rename generator to _chat_generator

* rm print

* Update releasenotes/notes/llm-evaluators-chat-generator-bf930fa6db019714.yaml

Co-authored-by: David S. Batista <dsbatista@gmail.com>

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
Stefano Fiorucci 2025-03-26 15:38:56 +01:00 committed by GitHub
parent 13941d8bd9
commit e4cf460bf6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 72 additions and 53 deletions

View File

@ -9,7 +9,8 @@ from tqdm import tqdm
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.builders import PromptBuilder
from haystack.components.generators import OpenAIGenerator
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.dataclasses.chat_message import ChatMessage
from haystack.utils import Secret, deserialize_secrets_inplace, deserialize_type, serialize_type
logger = logging.getLogger(__name__)
@ -110,7 +111,7 @@ class LLMEvaluator:
generator_kwargs = {**self.api_params}
if api_key:
generator_kwargs["api_key"] = api_key
self.generator = OpenAIGenerator(**generator_kwargs)
self._chat_generator = OpenAIChatGenerator(**generator_kwargs)
else:
raise ValueError(f"Unsupported API: {api}")
@ -200,12 +201,13 @@ class LLMEvaluator:
list_of_input_names_to_values = [dict(zip(input_names, v)) for v in values]
results: List[Optional[Dict[str, Any]]] = []
metadata = None
metadata = []
errors = 0
for input_names_to_values in tqdm(list_of_input_names_to_values, disable=not self.progress_bar):
prompt = self.builder.run(**input_names_to_values)
messages = [ChatMessage.from_user(prompt["prompt"])]
try:
result = self.generator.run(prompt=prompt["prompt"])
result = self._chat_generator.run(messages=messages)
except Exception as e:
if self.raise_on_failure:
raise ValueError(f"Error while generating response for prompt: {prompt}. Error: {e}")
@ -214,15 +216,15 @@ class LLMEvaluator:
errors += 1
continue
if self.is_valid_json_and_has_expected_keys(expected=self.outputs, received=result["replies"][0]):
parsed_result = json.loads(result["replies"][0])
if self.is_valid_json_and_has_expected_keys(expected=self.outputs, received=result["replies"][0].text):
parsed_result = json.loads(result["replies"][0].text)
results.append(parsed_result)
else:
results.append(None)
errors += 1
if self.api == "openai" and "meta" in result:
metadata = result["meta"]
if result["replies"][0].meta:
metadata.append(result["replies"][0].meta)
if errors > 0:
logger.warning(
@ -231,7 +233,7 @@ class LLMEvaluator:
len=len(list_of_input_names_to_values),
)
return {"results": results, "meta": metadata}
return {"results": results, "meta": metadata or None}
def prepare_template(self) -> str:
"""

View File

@ -0,0 +1,6 @@
---
upgrade:
- |
`LLMEvaluator`, `ContextRelevanceEvaluator`, and `FaithfulnessEvaluator` now internally use a
`ChatGenerator` instance instead of a `Generator` instance.
The public attribute `generator` has been removed and replaced with the private attribute
`_chat_generator`; code that accessed `evaluator.generator` directly must be updated.

View File

@ -11,6 +11,7 @@ import pytest
from haystack import Pipeline
from haystack.components.evaluators import ContextRelevanceEvaluator
from haystack.utils.auth import Secret
from haystack.dataclasses.chat_message import ChatMessage
class TestContextRelevanceEvaluator:
@ -18,7 +19,7 @@ class TestContextRelevanceEvaluator:
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = ContextRelevanceEvaluator()
assert component.api == "openai"
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.instructions == (
"Please extract only sentences from the provided context which are absolutely relevant and "
"required to answer the following question. If no relevant sentences are found, or if you "
@ -65,7 +66,7 @@ class TestContextRelevanceEvaluator:
{"inputs": {"questions": "Football is the most popular sport."}, "outputs": {"custom_score": 0}},
],
)
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.api == "openai"
assert component.examples == [
{"inputs": {"questions": "Damn, this is straight outta hell!!!"}, "outputs": {"custom_score": 1}},
@ -107,7 +108,7 @@ class TestContextRelevanceEvaluator:
}
component = ContextRelevanceEvaluator.from_dict(data)
assert component.api == "openai"
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.examples == [{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}]
pipeline = Pipeline()
@ -118,13 +119,13 @@ class TestContextRelevanceEvaluator:
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = ContextRelevanceEvaluator()
def generator_run(self, *args, **kwargs):
if "Football" in kwargs["prompt"]:
return {"replies": ['{"relevant_statements": ["a", "b"], "score": 1}']}
def chat_generator_run(self, *args, **kwargs):
if "Football" in kwargs["messages"][0].text:
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["a", "b"], "score": 1}')]}
else:
return {"replies": ['{"relevant_statements": [], "score": 0}']}
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": [], "score": 0}')]}
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [
@ -152,13 +153,13 @@ class TestContextRelevanceEvaluator:
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = ContextRelevanceEvaluator()
def generator_run(self, *args, **kwargs):
if "Football" in kwargs["prompt"]:
return {"replies": ['{"relevant_statements": ["a", "b"], "score": 1}']}
def chat_generator_run(self, *args, **kwargs):
if "Football" in kwargs["messages"][0].text:
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["a", "b"], "score": 1}')]}
else:
return {"replies": ['{"relevant_statements": [], "score": 0}']}
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": [], "score": 0}')]}
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [
@ -188,13 +189,13 @@ class TestContextRelevanceEvaluator:
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = ContextRelevanceEvaluator(raise_on_failure=False)
def generator_run(self, *args, **kwargs):
if "Python" in kwargs["prompt"]:
def chat_generator_run(self, *args, **kwargs):
if "Python" in kwargs["messages"][0].text:
raise Exception("OpenAI API request failed.")
else:
return {"replies": ['{"relevant_statements": ["c", "d"], "score": 1}']}
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["c", "d"], "score": 1}')]}
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [

View File

@ -10,6 +10,7 @@ import pytest
from haystack import Pipeline
from haystack.components.evaluators import FaithfulnessEvaluator
from haystack.utils.auth import Secret
from haystack.dataclasses.chat_message import ChatMessage
class TestFaithfulnessEvaluator:
@ -17,7 +18,7 @@ class TestFaithfulnessEvaluator:
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = FaithfulnessEvaluator()
assert component.api == "openai"
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.instructions == (
"Your task is to judge the faithfulness or groundedness of statements based "
"on context information. First, please extract statements from a provided predicted "
@ -84,7 +85,7 @@ class TestFaithfulnessEvaluator:
},
],
)
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.api == "openai"
assert component.examples == [
{"inputs": {"predicted_answers": "Damn, this is straight outta hell!!!"}, "outputs": {"custom_score": 1}},
@ -132,7 +133,7 @@ class TestFaithfulnessEvaluator:
}
component = FaithfulnessEvaluator.from_dict(data)
assert component.api == "openai"
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.examples == [
{"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
]
@ -145,13 +146,17 @@ class TestFaithfulnessEvaluator:
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = FaithfulnessEvaluator()
def generator_run(self, *args, **kwargs):
if "Football" in kwargs["prompt"]:
return {"replies": ['{"statements": ["a", "b"], "statement_scores": [1, 0]}']}
def chat_generator_run(self, *args, **kwargs):
if "Football" in kwargs["messages"][0].text:
return {
"replies": [ChatMessage.from_assistant('{"statements": ["a", "b"], "statement_scores": [1, 0]}')]
}
else:
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}
return {
"replies": [ChatMessage.from_assistant('{"statements": ["c", "d"], "statement_scores": [1, 1]}')]
}
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [
@ -186,13 +191,15 @@ class TestFaithfulnessEvaluator:
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = FaithfulnessEvaluator()
def generator_run(self, *args, **kwargs):
if "Football" in kwargs["prompt"]:
return {"replies": ['{"statements": ["a", "b"], "statement_scores": [1, 0]}']}
def chat_generator_run(self, *args, **kwargs):
if "Football" in kwargs["messages"][0].text:
return {
"replies": [ChatMessage.from_assistant('{"statements": ["a", "b"], "statement_scores": [1, 0]}')]
}
else:
return {"replies": ['{"statements": [], "statement_scores": []}']}
return {"replies": [ChatMessage.from_assistant('{"statements": [], "statement_scores": []}')]}
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [
@ -229,13 +236,15 @@ class TestFaithfulnessEvaluator:
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = FaithfulnessEvaluator(raise_on_failure=False)
def generator_run(self, *args, **kwargs):
if "Python" in kwargs["prompt"]:
def chat_generator_run(self, *args, **kwargs):
if "Python" in kwargs["messages"][0].text:
raise Exception("OpenAI API request failed.")
else:
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}
return {
"replies": [ChatMessage.from_assistant('{"statements": ["c", "d"], "statement_scores": [1, 1]}')]
}
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [

View File

@ -9,6 +9,7 @@ import pytest
from haystack import Pipeline
from haystack.components.evaluators import LLMEvaluator
from haystack.utils.auth import Secret
from haystack.dataclasses.chat_message import ChatMessage
class TestLLMEvaluator:
@ -23,7 +24,7 @@ class TestLLMEvaluator:
],
)
assert component.api == "openai"
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42}}
assert component.instructions == "test-instruction"
assert component.inputs == [("predicted_answers", List[str])]
@ -64,7 +65,7 @@ class TestLLMEvaluator:
},
],
)
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 43}}
assert component.api == "openai"
assert component.examples == [
@ -239,7 +240,7 @@ class TestLLMEvaluator:
}
component = LLMEvaluator.from_dict(data)
assert component.api == "openai"
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42}}
assert component.instructions == "test-instruction"
assert component.inputs == [("predicted_answers", List[str])]
@ -317,10 +318,10 @@ class TestLLMEvaluator:
],
)
def generator_run(self, *args, **kwargs):
return {"replies": ['{"score": 0.5}']}
def chat_generator_run(self, *args, **kwargs):
return {"replies": [ChatMessage.from_assistant('{"score": 0.5}')]}
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
with pytest.raises(ValueError):
component.run(questions=["What is the capital of Germany?"], predicted_answers=[["Berlin"], ["Paris"]])
@ -342,10 +343,10 @@ class TestLLMEvaluator:
],
)
def generator_run(self, *args, **kwargs):
return {"replies": ['{"score": 0.5}']}
def chat_generator_run(self, *args, **kwargs):
return {"replies": [ChatMessage.from_assistant('{"score": 0.5}')]}
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
results = component.run(questions=["What is the capital of Germany?"], predicted_answers=["Berlin"])
assert results == {"results": [{"score": 0.5}], "meta": None}
@ -471,7 +472,7 @@ class TestLLMEvaluator:
},
],
)
assert component.generator.client.api_key == "test-api-key"
assert component._chat_generator.client.api_key == "test-api-key"
assert component.api_params == {
"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42},
"api_base_url": "http://127.0.0.1:11434/v1",