mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-30 08:56:16 +00:00
refactor!: use Chat Generator in LLM evaluators (#9116)
* use chatgenerator instead of generator
* rename generator to _chat_generator
* rm print
* Update releasenotes/notes/llm-evaluators-chat-generator-bf930fa6db019714.yaml

Co-authored-by: David S. Batista <dsbatista@gmail.com>

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
parent
13941d8bd9
commit
e4cf460bf6
@ -9,7 +9,8 @@ from tqdm import tqdm
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.components.builders import PromptBuilder
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.components.generators.chat.openai import OpenAIChatGenerator
|
||||
from haystack.dataclasses.chat_message import ChatMessage
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace, deserialize_type, serialize_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -110,7 +111,7 @@ class LLMEvaluator:
|
||||
generator_kwargs = {**self.api_params}
|
||||
if api_key:
|
||||
generator_kwargs["api_key"] = api_key
|
||||
self.generator = OpenAIGenerator(**generator_kwargs)
|
||||
self._chat_generator = OpenAIChatGenerator(**generator_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported API: {api}")
|
||||
|
||||
@ -200,12 +201,13 @@ class LLMEvaluator:
|
||||
list_of_input_names_to_values = [dict(zip(input_names, v)) for v in values]
|
||||
|
||||
results: List[Optional[Dict[str, Any]]] = []
|
||||
metadata = None
|
||||
metadata = []
|
||||
errors = 0
|
||||
for input_names_to_values in tqdm(list_of_input_names_to_values, disable=not self.progress_bar):
|
||||
prompt = self.builder.run(**input_names_to_values)
|
||||
messages = [ChatMessage.from_user(prompt["prompt"])]
|
||||
try:
|
||||
result = self.generator.run(prompt=prompt["prompt"])
|
||||
result = self._chat_generator.run(messages=messages)
|
||||
except Exception as e:
|
||||
if self.raise_on_failure:
|
||||
raise ValueError(f"Error while generating response for prompt: {prompt}. Error: {e}")
|
||||
@ -214,15 +216,15 @@ class LLMEvaluator:
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
if self.is_valid_json_and_has_expected_keys(expected=self.outputs, received=result["replies"][0]):
|
||||
parsed_result = json.loads(result["replies"][0])
|
||||
if self.is_valid_json_and_has_expected_keys(expected=self.outputs, received=result["replies"][0].text):
|
||||
parsed_result = json.loads(result["replies"][0].text)
|
||||
results.append(parsed_result)
|
||||
else:
|
||||
results.append(None)
|
||||
errors += 1
|
||||
|
||||
if self.api == "openai" and "meta" in result:
|
||||
metadata = result["meta"]
|
||||
if result["replies"][0].meta:
|
||||
metadata.append(result["replies"][0].meta)
|
||||
|
||||
if errors > 0:
|
||||
logger.warning(
|
||||
@ -231,7 +233,7 @@ class LLMEvaluator:
|
||||
len=len(list_of_input_names_to_values),
|
||||
)
|
||||
|
||||
return {"results": results, "meta": metadata}
|
||||
return {"results": results, "meta": metadata or None}
|
||||
|
||||
def prepare_template(self) -> str:
|
||||
"""
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
upgrade:
|
||||
- |
|
||||
`LLMEvaluator`, `ContextRelevanceEvaluator`, and `FaithfulnessEvaluator` now internally use a
|
||||
`ChatGenerator` instance instead of a `Generator` instance.
|
||||
The public attribute `generator` has been replaced with the private attribute `_chat_generator`; code that accessed `generator` directly must be updated.
|
||||
@ -11,6 +11,7 @@ import pytest
|
||||
from haystack import Pipeline
|
||||
from haystack.components.evaluators import ContextRelevanceEvaluator
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.dataclasses.chat_message import ChatMessage
|
||||
|
||||
|
||||
class TestContextRelevanceEvaluator:
|
||||
@ -18,7 +19,7 @@ class TestContextRelevanceEvaluator:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = ContextRelevanceEvaluator()
|
||||
assert component.api == "openai"
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.instructions == (
|
||||
"Please extract only sentences from the provided context which are absolutely relevant and "
|
||||
"required to answer the following question. If no relevant sentences are found, or if you "
|
||||
@ -65,7 +66,7 @@ class TestContextRelevanceEvaluator:
|
||||
{"inputs": {"questions": "Football is the most popular sport."}, "outputs": {"custom_score": 0}},
|
||||
],
|
||||
)
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.api == "openai"
|
||||
assert component.examples == [
|
||||
{"inputs": {"questions": "Damn, this is straight outta hell!!!"}, "outputs": {"custom_score": 1}},
|
||||
@ -107,7 +108,7 @@ class TestContextRelevanceEvaluator:
|
||||
}
|
||||
component = ContextRelevanceEvaluator.from_dict(data)
|
||||
assert component.api == "openai"
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.examples == [{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}]
|
||||
|
||||
pipeline = Pipeline()
|
||||
@ -118,13 +119,13 @@ class TestContextRelevanceEvaluator:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = ContextRelevanceEvaluator()
|
||||
|
||||
def generator_run(self, *args, **kwargs):
|
||||
if "Football" in kwargs["prompt"]:
|
||||
return {"replies": ['{"relevant_statements": ["a", "b"], "score": 1}']}
|
||||
def chat_generator_run(self, *args, **kwargs):
|
||||
if "Football" in kwargs["messages"][0].text:
|
||||
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["a", "b"], "score": 1}')]}
|
||||
else:
|
||||
return {"replies": ['{"relevant_statements": [], "score": 0}']}
|
||||
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": [], "score": 0}')]}
|
||||
|
||||
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
|
||||
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
|
||||
|
||||
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
|
||||
contexts = [
|
||||
@ -152,13 +153,13 @@ class TestContextRelevanceEvaluator:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = ContextRelevanceEvaluator()
|
||||
|
||||
def generator_run(self, *args, **kwargs):
|
||||
if "Football" in kwargs["prompt"]:
|
||||
return {"replies": ['{"relevant_statements": ["a", "b"], "score": 1}']}
|
||||
def chat_generator_run(self, *args, **kwargs):
|
||||
if "Football" in kwargs["messages"][0].text:
|
||||
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["a", "b"], "score": 1}')]}
|
||||
else:
|
||||
return {"replies": ['{"relevant_statements": [], "score": 0}']}
|
||||
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": [], "score": 0}')]}
|
||||
|
||||
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
|
||||
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
|
||||
|
||||
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
|
||||
contexts = [
|
||||
@ -188,13 +189,13 @@ class TestContextRelevanceEvaluator:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = ContextRelevanceEvaluator(raise_on_failure=False)
|
||||
|
||||
def generator_run(self, *args, **kwargs):
|
||||
if "Python" in kwargs["prompt"]:
|
||||
def chat_generator_run(self, *args, **kwargs):
|
||||
if "Python" in kwargs["messages"][0].text:
|
||||
raise Exception("OpenAI API request failed.")
|
||||
else:
|
||||
return {"replies": ['{"relevant_statements": ["c", "d"], "score": 1}']}
|
||||
return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["c", "d"], "score": 1}')]}
|
||||
|
||||
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
|
||||
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
|
||||
|
||||
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
|
||||
contexts = [
|
||||
|
||||
@ -10,6 +10,7 @@ import pytest
|
||||
from haystack import Pipeline
|
||||
from haystack.components.evaluators import FaithfulnessEvaluator
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.dataclasses.chat_message import ChatMessage
|
||||
|
||||
|
||||
class TestFaithfulnessEvaluator:
|
||||
@ -17,7 +18,7 @@ class TestFaithfulnessEvaluator:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = FaithfulnessEvaluator()
|
||||
assert component.api == "openai"
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.instructions == (
|
||||
"Your task is to judge the faithfulness or groundedness of statements based "
|
||||
"on context information. First, please extract statements from a provided predicted "
|
||||
@ -84,7 +85,7 @@ class TestFaithfulnessEvaluator:
|
||||
},
|
||||
],
|
||||
)
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.api == "openai"
|
||||
assert component.examples == [
|
||||
{"inputs": {"predicted_answers": "Damn, this is straight outta hell!!!"}, "outputs": {"custom_score": 1}},
|
||||
@ -132,7 +133,7 @@ class TestFaithfulnessEvaluator:
|
||||
}
|
||||
component = FaithfulnessEvaluator.from_dict(data)
|
||||
assert component.api == "openai"
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.examples == [
|
||||
{"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
|
||||
]
|
||||
@ -145,13 +146,17 @@ class TestFaithfulnessEvaluator:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = FaithfulnessEvaluator()
|
||||
|
||||
def generator_run(self, *args, **kwargs):
|
||||
if "Football" in kwargs["prompt"]:
|
||||
return {"replies": ['{"statements": ["a", "b"], "statement_scores": [1, 0]}']}
|
||||
def chat_generator_run(self, *args, **kwargs):
|
||||
if "Football" in kwargs["messages"][0].text:
|
||||
return {
|
||||
"replies": [ChatMessage.from_assistant('{"statements": ["a", "b"], "statement_scores": [1, 0]}')]
|
||||
}
|
||||
else:
|
||||
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}
|
||||
return {
|
||||
"replies": [ChatMessage.from_assistant('{"statements": ["c", "d"], "statement_scores": [1, 1]}')]
|
||||
}
|
||||
|
||||
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
|
||||
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
|
||||
|
||||
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
|
||||
contexts = [
|
||||
@ -186,13 +191,15 @@ class TestFaithfulnessEvaluator:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = FaithfulnessEvaluator()
|
||||
|
||||
def generator_run(self, *args, **kwargs):
|
||||
if "Football" in kwargs["prompt"]:
|
||||
return {"replies": ['{"statements": ["a", "b"], "statement_scores": [1, 0]}']}
|
||||
def chat_generator_run(self, *args, **kwargs):
|
||||
if "Football" in kwargs["messages"][0].text:
|
||||
return {
|
||||
"replies": [ChatMessage.from_assistant('{"statements": ["a", "b"], "statement_scores": [1, 0]}')]
|
||||
}
|
||||
else:
|
||||
return {"replies": ['{"statements": [], "statement_scores": []}']}
|
||||
return {"replies": [ChatMessage.from_assistant('{"statements": [], "statement_scores": []}')]}
|
||||
|
||||
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
|
||||
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
|
||||
|
||||
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
|
||||
contexts = [
|
||||
@ -229,13 +236,15 @@ class TestFaithfulnessEvaluator:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = FaithfulnessEvaluator(raise_on_failure=False)
|
||||
|
||||
def generator_run(self, *args, **kwargs):
|
||||
if "Python" in kwargs["prompt"]:
|
||||
def chat_generator_run(self, *args, **kwargs):
|
||||
if "Python" in kwargs["messages"][0].text:
|
||||
raise Exception("OpenAI API request failed.")
|
||||
else:
|
||||
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}
|
||||
return {
|
||||
"replies": [ChatMessage.from_assistant('{"statements": ["c", "d"], "statement_scores": [1, 1]}')]
|
||||
}
|
||||
|
||||
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
|
||||
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
|
||||
|
||||
questions = ["Which is the most popular global sport?", "Who created the Python language?"]
|
||||
contexts = [
|
||||
|
||||
@ -9,6 +9,7 @@ import pytest
|
||||
from haystack import Pipeline
|
||||
from haystack.components.evaluators import LLMEvaluator
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.dataclasses.chat_message import ChatMessage
|
||||
|
||||
|
||||
class TestLLMEvaluator:
|
||||
@ -23,7 +24,7 @@ class TestLLMEvaluator:
|
||||
],
|
||||
)
|
||||
assert component.api == "openai"
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42}}
|
||||
assert component.instructions == "test-instruction"
|
||||
assert component.inputs == [("predicted_answers", List[str])]
|
||||
@ -64,7 +65,7 @@ class TestLLMEvaluator:
|
||||
},
|
||||
],
|
||||
)
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 43}}
|
||||
assert component.api == "openai"
|
||||
assert component.examples == [
|
||||
@ -239,7 +240,7 @@ class TestLLMEvaluator:
|
||||
}
|
||||
component = LLMEvaluator.from_dict(data)
|
||||
assert component.api == "openai"
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42}}
|
||||
assert component.instructions == "test-instruction"
|
||||
assert component.inputs == [("predicted_answers", List[str])]
|
||||
@ -317,10 +318,10 @@ class TestLLMEvaluator:
|
||||
],
|
||||
)
|
||||
|
||||
def generator_run(self, *args, **kwargs):
|
||||
return {"replies": ['{"score": 0.5}']}
|
||||
def chat_generator_run(self, *args, **kwargs):
|
||||
return {"replies": [ChatMessage.from_assistant('{"score": 0.5}')]}
|
||||
|
||||
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
|
||||
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
component.run(questions=["What is the capital of Germany?"], predicted_answers=[["Berlin"], ["Paris"]])
|
||||
@ -342,10 +343,10 @@ class TestLLMEvaluator:
|
||||
],
|
||||
)
|
||||
|
||||
def generator_run(self, *args, **kwargs):
|
||||
return {"replies": ['{"score": 0.5}']}
|
||||
def chat_generator_run(self, *args, **kwargs):
|
||||
return {"replies": [ChatMessage.from_assistant('{"score": 0.5}')]}
|
||||
|
||||
monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)
|
||||
monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run)
|
||||
|
||||
results = component.run(questions=["What is the capital of Germany?"], predicted_answers=["Berlin"])
|
||||
assert results == {"results": [{"score": 0.5}], "meta": None}
|
||||
@ -471,7 +472,7 @@ class TestLLMEvaluator:
|
||||
},
|
||||
],
|
||||
)
|
||||
assert component.generator.client.api_key == "test-api-key"
|
||||
assert component._chat_generator.client.api_key == "test-api-key"
|
||||
assert component.api_params == {
|
||||
"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42},
|
||||
"api_base_url": "http://127.0.0.1:11434/v1",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user