Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-08 13:06:29 +00:00
fix: typo in sas_evaluator arg (#7486)
* fixing typo on SAS arg
* fixing tests
This commit is contained in:
parent
0dbb98c0a0
commit
aae2b31359
@@ -129,7 +129,7 @@ class SASEvaluator:
         self._similarity_model = SentenceTransformer(self._model, device=device, use_auth_token=token)
 
     @component.output_types(score=float, individual_scores=List[float])
-    def run(self, ground_truths_answers: List[str], predicted_answers: List[str]) -> Dict[str, Any]:
+    def run(self, ground_truth_answers: List[str], predicted_answers: List[str]) -> Dict[str, Any]:
         """
         Run the SASEvaluator to compute the Semantic Answer Similarity (SAS) between a list of predicted answers
         and a list of ground truth answers. Both must be list of strings of same length.
@@ -143,7 +143,7 @@ class SASEvaluator:
         - `score`: Mean SAS score over all the predictions/ground-truth pairs.
         - `individual_scores`: A list of similarity scores for each prediction/ground-truth pair.
         """
-        if len(ground_truths_answers) != len(predicted_answers):
+        if len(ground_truth_answers) != len(predicted_answers):
             raise ValueError("The number of predictions and labels must be the same.")
 
         if len(predicted_answers) == 0:
@@ -155,7 +155,7 @@ class SASEvaluator:
 
         if isinstance(self._similarity_model, CrossEncoder):
             # For Cross Encoders we create a list of pairs of predictions and labels
-            sentence_pairs = list(zip(predicted_answers, ground_truths_answers))
+            sentence_pairs = list(zip(predicted_answers, ground_truth_answers))
             similarity_scores = self._similarity_model.predict(
                 sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
             )
@@ -174,7 +174,7 @@ class SASEvaluator:
                 predicted_answers, batch_size=self._batch_size, convert_to_tensor=True
             )
             label_embeddings = self._similarity_model.encode(
-                ground_truths_answers, batch_size=self._batch_size, convert_to_tensor=True
+                ground_truth_answers, batch_size=self._batch_size, convert_to_tensor=True
            )
 
             # Compute cosine-similarities
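For context, the bi-encoder branch in the hunk above embeds both lists and then scores each aligned (prediction, ground truth) pair. Below is a minimal sketch of that pairwise cosine-similarity step using sentence-transformers directly; the model name and variable names are illustrative, not taken from this diff, and this is not the evaluator's exact implementation.

from sentence_transformers import SentenceTransformer, util

# Illustrative model choice; SASEvaluator's actual default model is configured elsewhere.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

predicted_answers = ["Berlin", "Paris"]
ground_truth_answers = ["Berlin is the capital of Germany", "Paris"]

# Embed both lists, mirroring the encode(..., convert_to_tensor=True) calls in the diff.
prediction_embeddings = model.encode(predicted_answers, convert_to_tensor=True)
label_embeddings = model.encode(ground_truth_answers, convert_to_tensor=True)

# util.cos_sim returns an NxN similarity matrix; the diagonal holds the
# score for each aligned (prediction, ground truth) pair.
similarity_matrix = util.cos_sim(prediction_embeddings, label_embeddings)
individual_scores = similarity_matrix.diagonal().tolist()
score = sum(individual_scores) / len(individual_scores)  # mean SAS over all pairs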
@@ -51,7 +51,7 @@ class TestSASEvaluator:
 
     def test_run_with_empty_inputs(self):
         evaluator = SASEvaluator()
-        result = evaluator.run(ground_truths_answers=[], predicted_answers=[])
+        result = evaluator.run(ground_truth_answers=[], predicted_answers=[])
         assert len(result) == 2
         assert result["score"] == 0.0
         assert result["individual_scores"] == [0.0]
@@ -68,7 +68,7 @@ class TestSASEvaluator:
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         with pytest.raises(ValueError):
-            evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+            evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions)
 
     def test_run_not_warmed_up(self):
         evaluator = SASEvaluator()
@@ -83,7 +83,7 @@ class TestSASEvaluator:
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         with pytest.raises(RuntimeError):
-            evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+            evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions)
 
     @pytest.mark.integration
     def test_run_with_matching_predictions(self):
@@ -99,7 +99,7 @@ class TestSASEvaluator:
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         evaluator.warm_up()
-        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+        result = evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions)
 
         assert len(result) == 2
         assert result["score"] == pytest.approx(1.0)
@@ -112,7 +112,7 @@ class TestSASEvaluator:
         ground_truths = ["US $2.3 billion"]
         evaluator.warm_up()
         result = evaluator.run(
-            ground_truths_answers=ground_truths, predicted_answers=["A construction budget of US $2.3 billion"]
+            ground_truth_answers=ground_truths, predicted_answers=["A construction budget of US $2.3 billion"]
         )
         assert len(result) == 2
         assert result["score"] == pytest.approx(0.689089, abs=1e-5)
@@ -132,7 +132,7 @@ class TestSASEvaluator:
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         evaluator.warm_up()
-        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+        result = evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions)
         assert len(result) == 2
         assert result["score"] == pytest.approx(0.8227189)
         assert result["individual_scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
@@ -151,7 +151,7 @@ class TestSASEvaluator:
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         evaluator.warm_up()
-        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+        result = evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions)
         assert len(result) == 2
         assert result["score"] == pytest.approx(1.0)
         assert result["individual_scores"] == pytest.approx([1.0, 1.0, 1.0])
@@ -170,7 +170,7 @@ class TestSASEvaluator:
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         evaluator.warm_up()
-        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+        result = evaluator.run(ground_truth_answers=ground_truths, predicted_answers=predictions)
         assert len(result) == 2
         assert result["score"] == pytest.approx(0.999967, abs=1e-5)
         assert result["individual_scores"] == pytest.approx(
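With the rename applied, calling the evaluator outside of pytest looks roughly like the sketch below. The ground_truth_answers keyword, warm_up() requirement, and the score / individual_scores result keys all come from the diff above; the import path is assumed from Haystack 2.x's component layout and may differ.

# Import path assumed from Haystack 2.x's component layout; adjust if it differs.
from haystack.components.evaluators import SASEvaluator

evaluator = SASEvaluator()
evaluator.warm_up()  # loads the similarity model; run() raises RuntimeError without this

result = evaluator.run(
    ground_truth_answers=["Berlin", "Paris"],               # corrected keyword (was ground_truths_answers)
    predicted_answers=["Berlin", "The capital is Paris"],   # must have the same length as the ground truths
)

print(result["score"])              # mean SAS over all prediction/ground-truth pairs
print(result["individual_scores"])  # one similarity score per pair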