refactor: Refactor SASEvaluator (#6998)

* Remove preprocessing from SASEvaluator and add warm_up method

* Update docstrings
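
A minimal usage sketch of the new API, inferred from this diff (hedged: the import path follows the "type" field emitted by to_dict()). Labels move from __init__ to run(), and warm_up() must be called before run(), otherwise a RuntimeError is raised:

    from haystack.components.eval.sas_evaluator import SASEvaluator

    evaluator = SASEvaluator()  # labels are no longer passed here
    evaluator.warm_up()         # loads the Bi- or Cross-Encoder model
    result = evaluator.run(
        labels=["US $2.3 billion"],
        predictions=["A construction budget of US $2.3 billion"],
    )
    print(result["sas"], result["scores"])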
Silvano Cerza 2024-02-15 16:05:43 +01:00 committed by GitHub
parent 06a9349095
commit 2a4e6a1de2
2 changed files with 85 additions and 215 deletions

haystack/components/eval/sas_evaluator.py

@@ -7,9 +7,7 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, expit
 from haystack.utils.auth import Secret, deserialize_secrets_inplace
-from .preprocess import _preprocess_text
-with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as metrics_import:
+with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as sas_import:
     from sentence_transformers import CrossEncoder, SentenceTransformer, util
     from transformers import AutoConfig
@@ -22,17 +20,11 @@ class SASEvaluator:
     The SAS is computed using a pre-trained model from the Hugging Face model hub. The model can be either a
     Bi-Encoder or a Cross-Encoder. The choice of the model is based on the `model` parameter.
     The default model is `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`.
     """
     def __init__(
         self,
-        labels: List[str],
         model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
-        regexes_to_ignore: Optional[List[str]] = None,
-        ignore_case: bool = False,
-        ignore_punctuation: bool = False,
-        ignore_numbers: bool = False,
         batch_size: int = 32,
         device: Optional[ComponentDevice] = None,
         token: Secret = Secret.from_env_var("HF_API_TOKEN", strict=False),
@@ -40,42 +32,25 @@ class SASEvaluator:
         """
         Creates a new instance of SASEvaluator.
-        :param labels: The list of expected answers.
         :param model: SentenceTransformers semantic textual similarity model, should be path or string pointing to
             a downloadable model.
-        :param regexes_to_ignore: A list of regular expressions. If provided, it removes substrings
-            matching these regular expressions from both predictions and labels before comparison. Defaults to None.
-        :param ignore_case: If True, performs case-insensitive comparison. Defaults to False.
-        :param ignore_punctuation: If True, removes punctuation from both predictions and labels before
-            comparison. Defaults to False.
-        :param ignore_numbers: If True, removes numerical digits from both predictions and labels
-            before comparison. Defaults to False.
         :param batch_size: Number of prediction-label pairs to encode at once.
         :param device: The device on which the model is loaded. If `None`, the default device is automatically
            selected.
         :param token: The Hugging Face token for HTTP bearer authorization.
             You can find your HF token at https://huggingface.co/settings/tokens.
         """
-        metrics_import.check()
+        sas_import.check()
-        self._labels = labels
         self._model = model
-        self._regexes_to_ignore = regexes_to_ignore
-        self._ignore_case = ignore_case
-        self._ignore_punctuation = ignore_punctuation
-        self._ignore_numbers = ignore_numbers
         self._batch_size = batch_size
         self._device = device
         self._token = token
+        self._similarity_model = None
     def to_dict(self) -> Dict[str, Any]:
         return default_to_dict(
             self,
-            labels=self._labels,
-            regexes_to_ignore=self._regexes_to_ignore,
-            ignore_case=self._ignore_case,
-            ignore_punctuation=self._ignore_punctuation,
-            ignore_numbers=self._ignore_numbers,
             model=self._model,
             batch_size=self._batch_size,
             device=self._device.to_dict() if self._device else None,
@@ -89,42 +64,54 @@ class SASEvaluator:
         data["init_parameters"]["device"] = ComponentDevice.from_dict(device)
         return default_from_dict(cls, data)
+    def warm_up(self):
+        """
+        Load the model used for evaluation
+        """
+        token = self._token.resolve_value() if self._token else None
+        config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
+        cross_encoder_used = False
+        if config.architectures:
+            cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
+        device = ComponentDevice.resolve_device(self._device).to_torch_str()
+        # Based on the Model string we can load either Bi-Encoders or Cross Encoders.
+        # Similarity computation changes for both approaches
+        if cross_encoder_used:
+            self._similarity_model = CrossEncoder(
+                self._model,
+                device=device,
+                tokenizer_args={"use_auth_token": token},
+                automodel_args={"use_auth_token": token},
+            )
+        else:
+            self._similarity_model = SentenceTransformer(self._model, device=device, use_auth_token=token)
     @component.output_types(sas=float, scores=List[float])
-    def run(self, predictions: List[str]) -> Dict[str, Any]:
-        if len(predictions) != len(self._labels):
+    def run(self, labels: List[str], predictions: List[str]) -> Dict[str, Any]:
+        """
+        Run the SASEvaluator to compute the Semantic Answer Similarity (SAS) between a list of predictions and a list of
+        labels. Both must be list of strings of same length.
+        :param predictions: List of predictions.
+        :param labels: List of labels against which the predictions are compared.
+        :returns: A dictionary with the following outputs:
+            * `sas` - Cumulative SAS score for the entire dataset.
+            * `scores` - A list of similarity scores for each prediction-label pair.
+        """
+        if len(labels) != len(predictions):
             raise ValueError("The number of predictions and labels must be the same.")
         if len(predictions) == 0:
             return {"sas": 0.0, "scores": [0.0]}
-        token = self._token.resolve_value() if self._token else None
+        if not self._similarity_model:
+            msg = "The model has not been initialized. Call warm_up() before running the evaluator."
+            raise RuntimeError(msg)
-        predictions = _preprocess_text(
-            predictions, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
-        )
-        labels = _preprocess_text(
-            self._labels, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
-        )
-        config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
-        cross_encoder_used = False
-        if config.architectures:
-            cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
-        device = ComponentDevice.resolve_device(self._device)
-        # Based on the Model string we can load either Bi-Encoders or Cross Encoders.
-        # Similarity computation changes for both approaches
-        if cross_encoder_used:
+        if isinstance(self._similarity_model, CrossEncoder):
             # For Cross Encoders we create a list of pairs of predictions and labels
-            similarity_model = CrossEncoder(
-                self._model,
-                device=device.to_torch_str(),
-                tokenizer_args={"use_auth_token": token},
-                automodel_args={"use_auth_token": token},
-            )
             sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)]
-            similarity_scores = similarity_model.predict(
+            similarity_scores = self._similarity_model.predict(
                 sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
             )
@@ -138,11 +125,12 @@ class SASEvaluator:
         else:
             # For Bi-encoders we create embeddings separately for predictions and labels
-            similarity_model = SentenceTransformer(self._model, device=device.to_torch_str(), use_auth_token=token)
-            predictions_embeddings = similarity_model.encode(
+            predictions_embeddings = self._similarity_model.encode(
                 predictions, batch_size=self._batch_size, convert_to_tensor=True
             )
-            label_embeddings = similarity_model.encode(labels, batch_size=self._batch_size, convert_to_tensor=True)
+            label_embeddings = self._similarity_model.encode(
+                labels, batch_size=self._batch_size, convert_to_tensor=True
+            )
             # Compute cosine-similarities
             scores = util.cos_sim(predictions_embeddings, label_embeddings)
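
Note: since preprocessing was removed from the component, callers who relied on regexes_to_ignore, ignore_case, ignore_punctuation, or ignore_numbers can normalize inputs themselves before calling run(). A minimal sketch under that assumption (the normalize() helper below is illustrative, not part of Haystack):

    import re
    import string

    def normalize(texts, regexes_to_ignore=None, ignore_case=False,
                  ignore_punctuation=False, ignore_numbers=False):
        # Mirror the removed options: strip regex matches, lowercase,
        # drop punctuation, drop digits.
        for pattern in regexes_to_ignore or []:
            texts = [re.sub(pattern, "", t) for t in texts]
        if ignore_case:
            texts = [t.lower() for t in texts]
        if ignore_punctuation:
            texts = [t.translate(str.maketrans("", "", string.punctuation)) for t in texts]
        if ignore_numbers:
            texts = [t.translate(str.maketrans("", "", string.digits)) for t in texts]
        return texts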

test/components/eval/test_sas_evaluator.py

@@ -7,14 +7,8 @@ from haystack.utils.device import ComponentDevice
 class TestSASEvaluator:
     def test_init_default(self, monkeypatch):
         monkeypatch.setenv("HF_API_TOKEN", "fake-token")
-        labels = ["label1", "label2", "label3"]
-        evaluator = SASEvaluator(labels=labels)
+        evaluator = SASEvaluator()
-        assert evaluator._labels == labels
-        assert evaluator._regexes_to_ignore is None
-        assert evaluator._ignore_case is False
-        assert evaluator._ignore_punctuation is False
-        assert evaluator._ignore_numbers is False
         assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
         assert evaluator._batch_size == 32
         assert evaluator._device is None
@@ -23,17 +17,11 @@ class TestSASEvaluator:
     def test_to_dict(self, monkeypatch):
         monkeypatch.setenv("HF_API_TOKEN", "fake-token")
-        labels = ["label1", "label2", "label3"]
-        evaluator = SASEvaluator(labels=labels, device=ComponentDevice.from_str("cuda:0"))
+        evaluator = SASEvaluator(device=ComponentDevice.from_str("cuda:0"))
         expected_dict = {
             "type": "haystack.components.eval.sas_evaluator.SASEvaluator",
             "init_parameters": {
-                "labels": labels,
-                "regexes_to_ignore": None,
-                "ignore_case": False,
-                "ignore_punctuation": False,
-                "ignore_numbers": False,
                 "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                 "batch_size": 32,
                 "device": {"type": "single", "device": "cuda:0"},
@@ -48,11 +36,6 @@ class TestSASEvaluator:
             {
                 "type": "haystack.components.eval.sas_evaluator.SASEvaluator",
                 "init_parameters": {
-                    "labels": ["label1", "label2", "label3"],
-                    "regexes_to_ignore": None,
-                    "ignore_case": False,
-                    "ignore_punctuation": False,
-                    "ignore_numbers": False,
                     "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                     "batch_size": 32,
                     "device": {"type": "single", "device": "cuda:0"},
@@ -61,54 +44,62 @@ class TestSASEvaluator:
             }
         )
-        assert evaluator._labels == ["label1", "label2", "label3"]
-        assert evaluator._regexes_to_ignore is None
-        assert evaluator._ignore_case is False
-        assert evaluator._ignore_punctuation is False
-        assert evaluator._ignore_numbers is False
         assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
         assert evaluator._batch_size == 32
         assert evaluator._device.to_torch_str() == "cuda:0"
         assert evaluator._token.resolve_value() == "fake-token"
     @pytest.mark.integration
     def test_run_with_empty_inputs(self):
-        evaluator = SASEvaluator(labels=[])
-        result = evaluator.run(predictions=[])
+        evaluator = SASEvaluator()
+        result = evaluator.run(labels=[], predictions=[])
         assert len(result) == 2
         assert result["sas"] == 0.0
         assert result["scores"] == [0.0]
     @pytest.mark.integration
     def test_run_with_different_lengths(self):
+        evaluator = SASEvaluator()
         labels = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
         ]
-        evaluator = SASEvaluator(labels=labels)
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         with pytest.raises(ValueError):
-            evaluator.run(predictions)
+            evaluator.run(labels=labels, predictions=predictions)
     @pytest.mark.integration
-    def test_run_with_matching_predictions(self):
+    def test_run_not_warmed_up(self):
+        evaluator = SASEvaluator()
         labels = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
-        evaluator = SASEvaluator(labels=labels)
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
-        result = evaluator.run(predictions=predictions)
+        with pytest.raises(RuntimeError):
+            evaluator.run(labels=labels, predictions=predictions)
+    @pytest.mark.integration
+    def test_run_with_matching_predictions(self):
+        evaluator = SASEvaluator()
+        labels = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 2
         assert result["sas"] == pytest.approx(1.0)
@@ -116,177 +107,68 @@ class TestSASEvaluator:
     @pytest.mark.integration
     def test_run_with_single_prediction(self):
-        labels = ["US $2.3 billion"]
-        evaluator = SASEvaluator(labels=labels)
+        evaluator = SASEvaluator()
-        result = evaluator.run(predictions=["A construction budget of US $2.3 billion"])
+        labels = ["US $2.3 billion"]
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=["A construction budget of US $2.3 billion"])
         assert len(result) == 2
         assert result["sas"] == pytest.approx(0.689089, abs=1e-5)
         assert result["scores"] == pytest.approx([0.689089], abs=1e-5)
     @pytest.mark.integration
     def test_run_with_mismatched_predictions(self):
+        evaluator = SASEvaluator()
         labels = [
             "US $2.3 billion",
             "Paris's cultural magnificence is symbolized by the Eiffel Tower",
             "Japan was transformed into a modernized world power after the Meiji Restoration.",
         ]
-        evaluator = SASEvaluator(labels=labels)
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
-        result = evaluator.run(predictions=predictions)
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 2
         assert result["sas"] == pytest.approx(0.8227189)
         assert result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
-    @pytest.mark.integration
-    def test_run_with_ignore_case(self):
-        labels = [
-            "A construction budget of US $2.3 BILLION",
-            "The EIFFEL TOWER, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The MEIJI RESTORATION in 1868 transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(labels=labels, ignore_case=True)
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-    @pytest.mark.integration
-    def test_run_with_ignore_punctuation(self):
-        labels = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower completed in 1889 symbolizes Paris's cultural magnificence",
-            "The Meiji Restoration in 1868 transformed Japan into a modernized world power",
-        ]
-        evaluator = SASEvaluator(labels=labels, ignore_punctuation=True)
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-    @pytest.mark.integration
-    def test_run_with_ignore_numbers(self):
-        labels = [
-            "A construction budget of US $10.3 billion",
-            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(labels=labels, ignore_numbers=True)
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-    @pytest.mark.integration
-    def test_run_with_regex_to_ignore(self):
-        labels = [
-            "A construction budget of US $10.3 billion",
-            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(labels=labels, regexes_to_ignore=[r"\d+"])
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-    @pytest.mark.integration
-    def test_run_with_multiple_regex_to_ignore(self):
-        labels = [
-            "A construction budget of US $10.3 billion",
-            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(labels=labels, regexes_to_ignore=[r"\d+", r"[^\w\s]"])
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-    @pytest.mark.integration
-    def test_run_with_multiple_ignore_parameters(self):
-        labels = [
-            "A construction budget of US $10.3 billion",
-            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(
-            labels=labels,
-            ignore_numbers=True,
-            ignore_punctuation=True,
-            ignore_case=True,
-            regexes_to_ignore=[r"[^\w\s\d]+"],
-        )
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
     @pytest.mark.integration
     def test_run_with_bi_encoder_model(self):
+        evaluator = SASEvaluator(model="sentence-transformers/all-mpnet-base-v2")
         labels = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
-        evaluator = SASEvaluator(labels=labels, model="sentence-transformers/all-mpnet-base-v2")
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
-        result = evaluator.run(predictions=predictions)
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 2
         assert result["sas"] == pytest.approx(1.0)
         assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
     @pytest.mark.integration
     def test_run_with_cross_encoder_model(self):
+        evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
         labels = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
-        evaluator = SASEvaluator(labels=labels, model="cross-encoder/ms-marco-MiniLM-L-6-v2")
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
-        result = evaluator.run(predictions=predictions)
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 2
         assert result["sas"] == pytest.approx(0.999967, abs=1e-5)
         assert result["scores"] == pytest.approx([0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5)