From fe60eedee9a354c560b65a5738598be6ce1ff3aa Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Wed, 19 Jun 2024 13:47:38 +0200 Subject: [PATCH] fix: Fix deserialization of pipelines that contain `LLMEvaluator` subclasses (#7891) --- .../evaluators/context_relevance.py | 24 ++++++++++++--- .../components/evaluators/faithfulness.py | 26 +++++++++++++--- ...lass-deserialization-c633b2f95c84fe4b.yaml | 4 +++ .../test_context_relevance_evaluator.py | 26 ++++++++++++++++ .../evaluators/test_faithfulness_evaluator.py | 30 +++++++++++++++++++ 5 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 releasenotes/notes/fix-llmevaluator-subclass-deserialization-c633b2f95c84fe4b.yaml diff --git a/haystack/components/evaluators/context_relevance.py b/haystack/components/evaluators/context_relevance.py index cbd861f86..e7a727a13 100644 --- a/haystack/components/evaluators/context_relevance.py +++ b/haystack/components/evaluators/context_relevance.py @@ -6,9 +6,8 @@ from typing import Any, Dict, List, Optional from numpy import mean as np_mean -from haystack import default_from_dict +from haystack import component, default_from_dict, default_to_dict from haystack.components.evaluators.llm_evaluator import LLMEvaluator -from haystack.core.component import component from haystack.utils import Secret, deserialize_secrets_inplace # Private global variable for default examples to include in the prompt if the user does not provide any examples @@ -34,6 +33,7 @@ _DEFAULT_EXAMPLES = [ ] +@component class ContextRelevanceEvaluator(LLMEvaluator): """ Evaluator that checks if a provided context is relevant to the question. @@ -121,7 +121,7 @@ class ContextRelevanceEvaluator(LLMEvaluator): self.api = api self.api_key = api_key - super().__init__( + super(ContextRelevanceEvaluator, self).__init__( instructions=self.instructions, inputs=self.inputs, outputs=self.outputs, @@ -147,7 +147,7 @@ class ContextRelevanceEvaluator(LLMEvaluator): - `individual_scores`: A list of context relevance scores for each input question. - `results`: A list of dictionaries with `statements` and `statement_scores` for each input context. """ - result = super().run(questions=questions, contexts=contexts) + result = super(ContextRelevanceEvaluator, self).run(questions=questions, contexts=contexts) # calculate average statement relevance score per query for idx, res in enumerate(result["results"]): @@ -165,6 +165,22 @@ class ContextRelevanceEvaluator(LLMEvaluator): return result + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + A dictionary with serialized data. 
+ """ + return default_to_dict( + self, + api=self.api, + api_key=self.api_key.to_dict() if self.api_key else None, + examples=self.examples, + progress_bar=self.progress_bar, + raise_on_failure=self.raise_on_failure, + ) + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ContextRelevanceEvaluator": """ diff --git a/haystack/components/evaluators/faithfulness.py b/haystack/components/evaluators/faithfulness.py index 8455cd618..25b386734 100644 --- a/haystack/components/evaluators/faithfulness.py +++ b/haystack/components/evaluators/faithfulness.py @@ -6,9 +6,8 @@ from typing import Any, Dict, List, Optional from numpy import mean as np_mean -from haystack import default_from_dict +from haystack import component, default_from_dict, default_to_dict from haystack.components.evaluators.llm_evaluator import LLMEvaluator -from haystack.core.component import component from haystack.utils import Secret, deserialize_secrets_inplace # Default examples to include in the prompt if the user does not provide any examples @@ -46,6 +45,7 @@ _DEFAULT_EXAMPLES = [ ] +@component class FaithfulnessEvaluator(LLMEvaluator): """ Evaluator that checks if a generated answer can be inferred from the provided contexts. @@ -134,7 +134,7 @@ class FaithfulnessEvaluator(LLMEvaluator): self.api = api self.api_key = api_key - super().__init__( + super(FaithfulnessEvaluator, self).__init__( instructions=self.instructions, inputs=self.inputs, outputs=self.outputs, @@ -162,7 +162,9 @@ class FaithfulnessEvaluator(LLMEvaluator): - `individual_scores`: A list of faithfulness scores for each input answer. - `results`: A list of dictionaries with `statements` and `statement_scores` for each input answer. """ - result = super().run(questions=questions, contexts=contexts, predicted_answers=predicted_answers) + result = super(FaithfulnessEvaluator, self).run( + questions=questions, contexts=contexts, predicted_answers=predicted_answers + ) # calculate average statement faithfulness score per query for idx, res in enumerate(result["results"]): @@ -180,6 +182,22 @@ class FaithfulnessEvaluator(LLMEvaluator): return result + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + A dictionary with serialized data. + """ + return default_to_dict( + self, + api=self.api, + api_key=self.api_key.to_dict() if self.api_key else None, + examples=self.examples, + progress_bar=self.progress_bar, + raise_on_failure=self.raise_on_failure, + ) + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "FaithfulnessEvaluator": """ diff --git a/releasenotes/notes/fix-llmevaluator-subclass-deserialization-c633b2f95c84fe4b.yaml b/releasenotes/notes/fix-llmevaluator-subclass-deserialization-c633b2f95c84fe4b.yaml new file mode 100644 index 000000000..8a2adbca7 --- /dev/null +++ b/releasenotes/notes/fix-llmevaluator-subclass-deserialization-c633b2f95c84fe4b.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fix the deserialization of pipelines containing evaluator components that were subclasses of `LLMEvaluator`. 
diff --git a/test/components/evaluators/test_context_relevance_evaluator.py b/test/components/evaluators/test_context_relevance_evaluator.py index 2db69004d..f8dbfa2f0 100644 --- a/test/components/evaluators/test_context_relevance_evaluator.py +++ b/test/components/evaluators/test_context_relevance_evaluator.py @@ -8,6 +8,7 @@ import math import pytest +from haystack import Pipeline from haystack.components.evaluators import ContextRelevanceEvaluator from haystack.utils.auth import Secret @@ -71,6 +72,27 @@ class TestContextRelevanceEvaluator: {"inputs": {"questions": "Football is the most popular sport."}, "outputs": {"custom_score": 0}}, ] + def test_to_dict_with_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "test-api-key") + component = ContextRelevanceEvaluator( + api="openai", + api_key=Secret.from_env_var("ENV_VAR"), + examples=[{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}], + raise_on_failure=False, + progress_bar=False, + ) + data = component.to_dict() + assert data == { + "type": "haystack.components.evaluators.context_relevance.ContextRelevanceEvaluator", + "init_parameters": { + "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}, + "api": "openai", + "examples": [{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}], + "progress_bar": False, + "raise_on_failure": False, + }, + } + def test_from_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -87,6 +109,10 @@ class TestContextRelevanceEvaluator: assert component.generator.client.api_key == "test-api-key" assert component.examples == [{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}] + pipeline = Pipeline() + pipeline.add_component("evaluator", component) + assert pipeline.loads(pipeline.dumps()) + def test_run_calculates_mean_score(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = ContextRelevanceEvaluator() diff --git a/test/components/evaluators/test_faithfulness_evaluator.py b/test/components/evaluators/test_faithfulness_evaluator.py index 5c32f8c06..c24df9458 100644 --- a/test/components/evaluators/test_faithfulness_evaluator.py +++ b/test/components/evaluators/test_faithfulness_evaluator.py @@ -8,6 +8,7 @@ from typing import List import numpy as np import pytest +from haystack import Pipeline from haystack.components.evaluators import FaithfulnessEvaluator from haystack.utils.auth import Secret @@ -91,6 +92,31 @@ class TestFaithfulnessEvaluator: {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"custom_score": 0}}, ] + def test_to_dict_with_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "test-api-key") + component = FaithfulnessEvaluator( + api="openai", + api_key=Secret.from_env_var("ENV_VAR"), + examples=[ + {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}} + ], + raise_on_failure=False, + progress_bar=False, + ) + data = component.to_dict() + assert data == { + "type": "haystack.components.evaluators.faithfulness.FaithfulnessEvaluator", + "init_parameters": { + "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}, + "api": "openai", + "examples": [ + {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}} + ], + "progress_bar": False, + "raise_on_failure": False, + }, + } + def test_from_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -111,6 +137,10 @@ class 
TestFaithfulnessEvaluator: {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}} ] + pipeline = Pipeline() + pipeline.add_component("evaluator", component) + assert pipeline.loads(pipeline.dumps()) + def test_run_calculates_mean_score(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = FaithfulnessEvaluator()
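
For anyone maintaining their own `LLMEvaluator` subclass, the pattern this patch applies to `ContextRelevanceEvaluator` and `FaithfulnessEvaluator` can be restated as the sketch below (a hypothetical example, not code from the patch: the class name, instructions, and example data are placeholders, and constructing it requires a resolvable OPENAI_API_KEY). Decorate the subclass with `@component`, serialize only its own init parameters in `to_dict()`, and rebuild secrets in `from_dict()`.

    # Hypothetical LLMEvaluator subclass illustrating the serialization pattern from this patch.
    from typing import Any, Dict, List

    from haystack import component, default_from_dict, default_to_dict
    from haystack.components.evaluators.llm_evaluator import LLMEvaluator
    from haystack.utils import Secret, deserialize_secrets_inplace


    @component  # re-decorate the subclass so pipelines record its own type, not LLMEvaluator's
    class MyEvaluator(LLMEvaluator):
        def __init__(self, api: str = "openai", api_key: Secret = Secret.from_env_var("OPENAI_API_KEY")):
            self.api = api
            self.api_key = api_key
            super(MyEvaluator, self).__init__(
                instructions="Answer with 1 if the context is relevant to the question, else 0.",
                inputs=[("questions", List[str]), ("contexts", List[List[str]])],
                outputs=["score"],
                examples=[
                    {
                        "inputs": {"questions": "What is Python?", "contexts": ["Python is a language."]},
                        "outputs": {"score": 1},
                    }
                ],
                api=api,
                api_key=api_key,
            )

        def to_dict(self) -> Dict[str, Any]:
            # Serialize only this subclass's init parameters so from_dict() can rebuild it.
            return default_to_dict(self, api=self.api, api_key=self.api_key.to_dict() if self.api_key else None)

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "MyEvaluator":
            deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
            return default_from_dict(cls, data)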