fix: Fix deserialization of pipelines that contain LLMEvaluator subclasses (#7891)

Madeesh Kannan 2024-06-19 13:47:38 +02:00 committed by GitHub
parent 7c31d5f418
commit fe60eedee9
5 changed files with 102 additions and 8 deletions

haystack/components/evaluators/context_relevance.py

@@ -6,9 +6,8 @@ from typing import Any, Dict, List, Optional
 
 from numpy import mean as np_mean
 
-from haystack import default_from_dict
+from haystack import component, default_from_dict, default_to_dict
 from haystack.components.evaluators.llm_evaluator import LLMEvaluator
-from haystack.core.component import component
 from haystack.utils import Secret, deserialize_secrets_inplace
 
 # Private global variable for default examples to include in the prompt if the user does not provide any examples
@@ -34,6 +33,7 @@ _DEFAULT_EXAMPLES = [
 ]
 
 
+@component
 class ContextRelevanceEvaluator(LLMEvaluator):
     """
     Evaluator that checks if a provided context is relevant to the question.
@@ -121,7 +121,7 @@ class ContextRelevanceEvaluator(LLMEvaluator):
         self.api = api
         self.api_key = api_key
 
-        super().__init__(
+        super(ContextRelevanceEvaluator, self).__init__(
             instructions=self.instructions,
             inputs=self.inputs,
             outputs=self.outputs,
@@ -147,7 +147,7 @@ class ContextRelevanceEvaluator(LLMEvaluator):
             - `individual_scores`: A list of context relevance scores for each input question.
             - `results`: A list of dictionaries with `statements` and `statement_scores` for each input context.
         """
-        result = super().run(questions=questions, contexts=contexts)
+        result = super(ContextRelevanceEvaluator, self).run(questions=questions, contexts=contexts)
 
         # calculate average statement relevance score per query
         for idx, res in enumerate(result["results"]):
@@ -165,6 +165,22 @@ class ContextRelevanceEvaluator(LLMEvaluator):
 
         return result
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+
+        :returns:
+            A dictionary with serialized data.
+        """
+        return default_to_dict(
+            self,
+            api=self.api,
+            api_key=self.api_key.to_dict() if self.api_key else None,
+            examples=self.examples,
+            progress_bar=self.progress_bar,
+            raise_on_failure=self.raise_on_failure,
+        )
+
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "ContextRelevanceEvaluator":
         """

haystack/components/evaluators/faithfulness.py

@@ -6,9 +6,8 @@ from typing import Any, Dict, List, Optional
 
 from numpy import mean as np_mean
 
-from haystack import default_from_dict
+from haystack import component, default_from_dict, default_to_dict
 from haystack.components.evaluators.llm_evaluator import LLMEvaluator
-from haystack.core.component import component
 from haystack.utils import Secret, deserialize_secrets_inplace
 
 # Default examples to include in the prompt if the user does not provide any examples
@@ -46,6 +45,7 @@ _DEFAULT_EXAMPLES = [
 ]
 
 
+@component
 class FaithfulnessEvaluator(LLMEvaluator):
     """
     Evaluator that checks if a generated answer can be inferred from the provided contexts.
@@ -134,7 +134,7 @@ class FaithfulnessEvaluator(LLMEvaluator):
         self.api = api
         self.api_key = api_key
 
-        super().__init__(
+        super(FaithfulnessEvaluator, self).__init__(
             instructions=self.instructions,
             inputs=self.inputs,
             outputs=self.outputs,
@@ -162,7 +162,9 @@ class FaithfulnessEvaluator(LLMEvaluator):
             - `individual_scores`: A list of faithfulness scores for each input answer.
             - `results`: A list of dictionaries with `statements` and `statement_scores` for each input answer.
         """
-        result = super().run(questions=questions, contexts=contexts, predicted_answers=predicted_answers)
+        result = super(FaithfulnessEvaluator, self).run(
+            questions=questions, contexts=contexts, predicted_answers=predicted_answers
+        )
 
         # calculate average statement faithfulness score per query
         for idx, res in enumerate(result["results"]):
@@ -180,6 +182,22 @@ class FaithfulnessEvaluator(LLMEvaluator):
 
         return result
 
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+
+        :returns:
+            A dictionary with serialized data.
+        """
+        return default_to_dict(
+            self,
+            api=self.api,
+            api_key=self.api_key.to_dict() if self.api_key else None,
+            examples=self.examples,
+            progress_bar=self.progress_bar,
+            raise_on_failure=self.raise_on_failure,
+        )
+
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "FaithfulnessEvaluator":
         """

Release note (new file)

@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix the deserialization of pipelines containing evaluator components that were subclasses of `LLMEvaluator`.

ContextRelevanceEvaluator tests

@@ -8,6 +8,7 @@ import math
 
 import pytest
 
+from haystack import Pipeline
 from haystack.components.evaluators import ContextRelevanceEvaluator
 from haystack.utils.auth import Secret
 
@@ -71,6 +72,27 @@ class TestContextRelevanceEvaluator:
             {"inputs": {"questions": "Football is the most popular sport."}, "outputs": {"custom_score": 0}},
         ]
 
+    def test_to_dict_with_parameters(self, monkeypatch):
+        monkeypatch.setenv("ENV_VAR", "test-api-key")
+        component = ContextRelevanceEvaluator(
+            api="openai",
+            api_key=Secret.from_env_var("ENV_VAR"),
+            examples=[{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}],
+            raise_on_failure=False,
+            progress_bar=False,
+        )
+        data = component.to_dict()
+        assert data == {
+            "type": "haystack.components.evaluators.context_relevance.ContextRelevanceEvaluator",
+            "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
+                "api": "openai",
+                "examples": [{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}],
+                "progress_bar": False,
+                "raise_on_failure": False,
+            },
+        }
+
     def test_from_dict(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
 
@@ -87,6 +109,10 @@ class TestContextRelevanceEvaluator:
         assert component.generator.client.api_key == "test-api-key"
         assert component.examples == [{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}]
 
+        pipeline = Pipeline()
+        pipeline.add_component("evaluator", component)
+        assert pipeline.loads(pipeline.dumps())
+
     def test_run_calculates_mean_score(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         component = ContextRelevanceEvaluator()
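
Outside pytest, the scenario these new test lines guard looks roughly like the sketch below: serialize a pipeline that contains the evaluator and load it back, which is the deserialization step this commit fixes. The sketch is illustrative, assumes a dummy `OPENAI_API_KEY`, and never runs the LLM.

import os

from haystack import Pipeline
from haystack.components.evaluators import ContextRelevanceEvaluator

os.environ.setdefault("OPENAI_API_KEY", "test-api-key")

pipeline = Pipeline()
pipeline.add_component("evaluator", ContextRelevanceEvaluator())

yaml_str = pipeline.dumps()          # serialize the whole pipeline to YAML
restored = Pipeline.loads(yaml_str)  # deserialization: the step this commit fixes

assert isinstance(restored.get_component("evaluator"), ContextRelevanceEvaluator)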

FaithfulnessEvaluator tests

@@ -8,6 +8,7 @@ from typing import List
 
 import numpy as np
 import pytest
 
+from haystack import Pipeline
 from haystack.components.evaluators import FaithfulnessEvaluator
 from haystack.utils.auth import Secret
 
@@ -91,6 +92,31 @@ class TestFaithfulnessEvaluator:
             {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"custom_score": 0}},
         ]
 
+    def test_to_dict_with_parameters(self, monkeypatch):
+        monkeypatch.setenv("ENV_VAR", "test-api-key")
+        component = FaithfulnessEvaluator(
+            api="openai",
+            api_key=Secret.from_env_var("ENV_VAR"),
+            examples=[
+                {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
+            ],
+            raise_on_failure=False,
+            progress_bar=False,
+        )
+        data = component.to_dict()
+        assert data == {
+            "type": "haystack.components.evaluators.faithfulness.FaithfulnessEvaluator",
+            "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
+                "api": "openai",
+                "examples": [
+                    {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
+                ],
+                "progress_bar": False,
+                "raise_on_failure": False,
+            },
+        }
+
     def test_from_dict(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
 
@@ -111,6 +137,10 @@ class TestFaithfulnessEvaluator:
             {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
         ]
 
+        pipeline = Pipeline()
+        pipeline.add_component("evaluator", component)
+        assert pipeline.loads(pipeline.dumps())
+
     def test_run_calculates_mean_score(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         component = FaithfulnessEvaluator()