From e4cf460bf61836ef1be8c3bb6e3ed817aa63f33e Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 26 Mar 2025 15:38:56 +0100 Subject: [PATCH] refactor!: use Chat Generator in LLM evaluators (#9116) * use chatgenerator instead of generator * rename generator to _chat_generator * rm print * Update releasenotes/notes/llm-evaluators-chat-generator-bf930fa6db019714.yaml Co-authored-by: David S. Batista --------- Co-authored-by: David S. Batista --- .../components/evaluators/llm_evaluator.py | 20 +++++---- ...ators-chat-generator-bf930fa6db019714.yaml | 6 +++ .../test_context_relevance_evaluator.py | 35 +++++++-------- .../evaluators/test_faithfulness_evaluator.py | 43 +++++++++++-------- .../evaluators/test_llm_evaluator.py | 21 ++++----- 5 files changed, 72 insertions(+), 53 deletions(-) create mode 100644 releasenotes/notes/llm-evaluators-chat-generator-bf930fa6db019714.yaml diff --git a/haystack/components/evaluators/llm_evaluator.py b/haystack/components/evaluators/llm_evaluator.py index ac1aab7c8..da6d089b3 100644 --- a/haystack/components/evaluators/llm_evaluator.py +++ b/haystack/components/evaluators/llm_evaluator.py @@ -9,7 +9,8 @@ from tqdm import tqdm from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.builders import PromptBuilder -from haystack.components.generators import OpenAIGenerator +from haystack.components.generators.chat.openai import OpenAIChatGenerator +from haystack.dataclasses.chat_message import ChatMessage from haystack.utils import Secret, deserialize_secrets_inplace, deserialize_type, serialize_type logger = logging.getLogger(__name__) @@ -110,7 +111,7 @@ class LLMEvaluator: generator_kwargs = {**self.api_params} if api_key: generator_kwargs["api_key"] = api_key - self.generator = OpenAIGenerator(**generator_kwargs) + self._chat_generator = OpenAIChatGenerator(**generator_kwargs) else: raise ValueError(f"Unsupported API: {api}") @@ -200,12 +201,13 @@ class LLMEvaluator: list_of_input_names_to_values = [dict(zip(input_names, v)) for v in values] results: List[Optional[Dict[str, Any]]] = [] - metadata = None + metadata = [] errors = 0 for input_names_to_values in tqdm(list_of_input_names_to_values, disable=not self.progress_bar): prompt = self.builder.run(**input_names_to_values) + messages = [ChatMessage.from_user(prompt["prompt"])] try: - result = self.generator.run(prompt=prompt["prompt"]) + result = self._chat_generator.run(messages=messages) except Exception as e: if self.raise_on_failure: raise ValueError(f"Error while generating response for prompt: {prompt}. Error: {e}") @@ -214,15 +216,15 @@ class LLMEvaluator: errors += 1 continue - if self.is_valid_json_and_has_expected_keys(expected=self.outputs, received=result["replies"][0]): - parsed_result = json.loads(result["replies"][0]) + if self.is_valid_json_and_has_expected_keys(expected=self.outputs, received=result["replies"][0].text): + parsed_result = json.loads(result["replies"][0].text) results.append(parsed_result) else: results.append(None) errors += 1 - if self.api == "openai" and "meta" in result: - metadata = result["meta"] + if result["replies"][0].meta: + metadata.append(result["replies"][0].meta) if errors > 0: logger.warning( @@ -231,7 +233,7 @@ class LLMEvaluator: len=len(list_of_input_names_to_values), ) - return {"results": results, "meta": metadata} + return {"results": results, "meta": metadata or None} def prepare_template(self) -> str: """ diff --git a/releasenotes/notes/llm-evaluators-chat-generator-bf930fa6db019714.yaml b/releasenotes/notes/llm-evaluators-chat-generator-bf930fa6db019714.yaml new file mode 100644 index 000000000..0185a1fb5 --- /dev/null +++ b/releasenotes/notes/llm-evaluators-chat-generator-bf930fa6db019714.yaml @@ -0,0 +1,6 @@ +--- +upgrade: + - | + `LLMEvaluator`, `ContextRelevanceEvaluator`, and `FaithfulnessEvaluator` now internally use a + `ChatGenerator` instance instead of a `Generator` instance. + The public attribute `generator` has been replaced with `_chat_generator`. diff --git a/test/components/evaluators/test_context_relevance_evaluator.py b/test/components/evaluators/test_context_relevance_evaluator.py index 3e33a9f4c..78d7506c4 100644 --- a/test/components/evaluators/test_context_relevance_evaluator.py +++ b/test/components/evaluators/test_context_relevance_evaluator.py @@ -11,6 +11,7 @@ import pytest from haystack import Pipeline from haystack.components.evaluators import ContextRelevanceEvaluator from haystack.utils.auth import Secret +from haystack.dataclasses.chat_message import ChatMessage class TestContextRelevanceEvaluator: @@ -18,7 +19,7 @@ class TestContextRelevanceEvaluator: monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = ContextRelevanceEvaluator() assert component.api == "openai" - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.instructions == ( "Please extract only sentences from the provided context which are absolutely relevant and " "required to answer the following question. If no relevant sentences are found, or if you " @@ -65,7 +66,7 @@ class TestContextRelevanceEvaluator: {"inputs": {"questions": "Football is the most popular sport."}, "outputs": {"custom_score": 0}}, ], ) - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.api == "openai" assert component.examples == [ {"inputs": {"questions": "Damn, this is straight outta hell!!!"}, "outputs": {"custom_score": 1}}, @@ -107,7 +108,7 @@ class TestContextRelevanceEvaluator: } component = ContextRelevanceEvaluator.from_dict(data) assert component.api == "openai" - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.examples == [{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}] pipeline = Pipeline() @@ -118,13 +119,13 @@ class TestContextRelevanceEvaluator: monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = ContextRelevanceEvaluator() - def generator_run(self, *args, **kwargs): - if "Football" in kwargs["prompt"]: - return {"replies": ['{"relevant_statements": ["a", "b"], "score": 1}']} + def chat_generator_run(self, *args, **kwargs): + if "Football" in kwargs["messages"][0].text: + return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["a", "b"], "score": 1}')]} else: - return {"replies": ['{"relevant_statements": [], "score": 0}']} + return {"replies": [ChatMessage.from_assistant('{"relevant_statements": [], "score": 0}')]} - monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run) + monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run) questions = ["Which is the most popular global sport?", "Who created the Python language?"] contexts = [ @@ -152,13 +153,13 @@ class TestContextRelevanceEvaluator: monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = ContextRelevanceEvaluator() - def generator_run(self, *args, **kwargs): - if "Football" in kwargs["prompt"]: - return {"replies": ['{"relevant_statements": ["a", "b"], "score": 1}']} + def chat_generator_run(self, *args, **kwargs): + if "Football" in kwargs["messages"][0].text: + return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["a", "b"], "score": 1}')]} else: - return {"replies": ['{"relevant_statements": [], "score": 0}']} + return {"replies": [ChatMessage.from_assistant('{"relevant_statements": [], "score": 0}')]} - monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run) + monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run) questions = ["Which is the most popular global sport?", "Who created the Python language?"] contexts = [ @@ -188,13 +189,13 @@ class TestContextRelevanceEvaluator: monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = ContextRelevanceEvaluator(raise_on_failure=False) - def generator_run(self, *args, **kwargs): - if "Python" in kwargs["prompt"]: + def chat_generator_run(self, *args, **kwargs): + if "Python" in kwargs["messages"][0].text: raise Exception("OpenAI API request failed.") else: - return {"replies": ['{"relevant_statements": ["c", "d"], "score": 1}']} + return {"replies": [ChatMessage.from_assistant('{"relevant_statements": ["c", "d"], "score": 1}')]} - monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run) + monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run) questions = ["Which is the most popular global sport?", "Who created the Python language?"] contexts = [ diff --git a/test/components/evaluators/test_faithfulness_evaluator.py b/test/components/evaluators/test_faithfulness_evaluator.py index de92388ec..f92d280db 100644 --- a/test/components/evaluators/test_faithfulness_evaluator.py +++ b/test/components/evaluators/test_faithfulness_evaluator.py @@ -10,6 +10,7 @@ import pytest from haystack import Pipeline from haystack.components.evaluators import FaithfulnessEvaluator from haystack.utils.auth import Secret +from haystack.dataclasses.chat_message import ChatMessage class TestFaithfulnessEvaluator: @@ -17,7 +18,7 @@ class TestFaithfulnessEvaluator: monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = FaithfulnessEvaluator() assert component.api == "openai" - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.instructions == ( "Your task is to judge the faithfulness or groundedness of statements based " "on context information. First, please extract statements from a provided predicted " @@ -84,7 +85,7 @@ class TestFaithfulnessEvaluator: }, ], ) - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.api == "openai" assert component.examples == [ {"inputs": {"predicted_answers": "Damn, this is straight outta hell!!!"}, "outputs": {"custom_score": 1}}, @@ -132,7 +133,7 @@ class TestFaithfulnessEvaluator: } component = FaithfulnessEvaluator.from_dict(data) assert component.api == "openai" - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.examples == [ {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}} ] @@ -145,13 +146,17 @@ class TestFaithfulnessEvaluator: monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = FaithfulnessEvaluator() - def generator_run(self, *args, **kwargs): - if "Football" in kwargs["prompt"]: - return {"replies": ['{"statements": ["a", "b"], "statement_scores": [1, 0]}']} + def chat_generator_run(self, *args, **kwargs): + if "Football" in kwargs["messages"][0].text: + return { + "replies": [ChatMessage.from_assistant('{"statements": ["a", "b"], "statement_scores": [1, 0]}')] + } else: - return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']} + return { + "replies": [ChatMessage.from_assistant('{"statements": ["c", "d"], "statement_scores": [1, 1]}')] + } - monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run) + monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run) questions = ["Which is the most popular global sport?", "Who created the Python language?"] contexts = [ @@ -186,13 +191,15 @@ class TestFaithfulnessEvaluator: monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = FaithfulnessEvaluator() - def generator_run(self, *args, **kwargs): - if "Football" in kwargs["prompt"]: - return {"replies": ['{"statements": ["a", "b"], "statement_scores": [1, 0]}']} + def chat_generator_run(self, *args, **kwargs): + if "Football" in kwargs["messages"][0].text: + return { + "replies": [ChatMessage.from_assistant('{"statements": ["a", "b"], "statement_scores": [1, 0]}')] + } else: - return {"replies": ['{"statements": [], "statement_scores": []}']} + return {"replies": [ChatMessage.from_assistant('{"statements": [], "statement_scores": []}')]} - monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run) + monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run) questions = ["Which is the most popular global sport?", "Who created the Python language?"] contexts = [ @@ -229,13 +236,15 @@ class TestFaithfulnessEvaluator: monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = FaithfulnessEvaluator(raise_on_failure=False) - def generator_run(self, *args, **kwargs): - if "Python" in kwargs["prompt"]: + def chat_generator_run(self, *args, **kwargs): + if "Python" in kwargs["messages"][0].text: raise Exception("OpenAI API request failed.") else: - return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']} + return { + "replies": [ChatMessage.from_assistant('{"statements": ["c", "d"], "statement_scores": [1, 1]}')] + } - monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run) + monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run) questions = ["Which is the most popular global sport?", "Who created the Python language?"] contexts = [ diff --git a/test/components/evaluators/test_llm_evaluator.py b/test/components/evaluators/test_llm_evaluator.py index e0feb6c0f..b9fa900ff 100644 --- a/test/components/evaluators/test_llm_evaluator.py +++ b/test/components/evaluators/test_llm_evaluator.py @@ -9,6 +9,7 @@ import pytest from haystack import Pipeline from haystack.components.evaluators import LLMEvaluator from haystack.utils.auth import Secret +from haystack.dataclasses.chat_message import ChatMessage class TestLLMEvaluator: @@ -23,7 +24,7 @@ class TestLLMEvaluator: ], ) assert component.api == "openai" - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42}} assert component.instructions == "test-instruction" assert component.inputs == [("predicted_answers", List[str])] @@ -64,7 +65,7 @@ class TestLLMEvaluator: }, ], ) - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 43}} assert component.api == "openai" assert component.examples == [ @@ -239,7 +240,7 @@ class TestLLMEvaluator: } component = LLMEvaluator.from_dict(data) assert component.api == "openai" - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.api_params == {"generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42}} assert component.instructions == "test-instruction" assert component.inputs == [("predicted_answers", List[str])] @@ -317,10 +318,10 @@ class TestLLMEvaluator: ], ) - def generator_run(self, *args, **kwargs): - return {"replies": ['{"score": 0.5}']} + def chat_generator_run(self, *args, **kwargs): + return {"replies": [ChatMessage.from_assistant('{"score": 0.5}')]} - monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run) + monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run) with pytest.raises(ValueError): component.run(questions=["What is the capital of Germany?"], predicted_answers=[["Berlin"], ["Paris"]]) @@ -342,10 +343,10 @@ class TestLLMEvaluator: ], ) - def generator_run(self, *args, **kwargs): - return {"replies": ['{"score": 0.5}']} + def chat_generator_run(self, *args, **kwargs): + return {"replies": [ChatMessage.from_assistant('{"score": 0.5}')]} - monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run) + monkeypatch.setattr("haystack.components.evaluators.llm_evaluator.OpenAIChatGenerator.run", chat_generator_run) results = component.run(questions=["What is the capital of Germany?"], predicted_answers=["Berlin"]) assert results == {"results": [{"score": 0.5}], "meta": None} @@ -471,7 +472,7 @@ class TestLLMEvaluator: }, ], ) - assert component.generator.client.api_key == "test-api-key" + assert component._chat_generator.client.api_key == "test-api-key" assert component.api_params == { "generation_kwargs": {"response_format": {"type": "json_object"}, "seed": 42}, "api_base_url": "http://127.0.0.1:11434/v1",