mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-12 16:14:05 +00:00
refactor: Refactor SASEvaluator (#6998)
* Remove preprocessing from SASEvaluator and add warm_up method * Update docstrings
This commit is contained in:
parent
06a9349095
commit
2a4e6a1de2
@ -7,9 +7,7 @@ from haystack.lazy_imports import LazyImport
|
|||||||
from haystack.utils import ComponentDevice, expit
|
from haystack.utils import ComponentDevice, expit
|
||||||
from haystack.utils.auth import Secret, deserialize_secrets_inplace
|
from haystack.utils.auth import Secret, deserialize_secrets_inplace
|
||||||
|
|
||||||
from .preprocess import _preprocess_text
|
with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as sas_import:
|
||||||
|
|
||||||
with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as metrics_import:
|
|
||||||
from sentence_transformers import CrossEncoder, SentenceTransformer, util
|
from sentence_transformers import CrossEncoder, SentenceTransformer, util
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
@ -22,17 +20,11 @@ class SASEvaluator:
|
|||||||
|
|
||||||
The SAS is computed using a pre-trained model from the Hugging Face model hub. The model can be either a
|
The SAS is computed using a pre-trained model from the Hugging Face model hub. The model can be either a
|
||||||
Bi-Encoder or a Cross-Encoder. The choice of the model is based on the `model` parameter.
|
Bi-Encoder or a Cross-Encoder. The choice of the model is based on the `model` parameter.
|
||||||
The default model is `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
labels: List[str],
|
|
||||||
model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
||||||
regexes_to_ignore: Optional[List[str]] = None,
|
|
||||||
ignore_case: bool = False,
|
|
||||||
ignore_punctuation: bool = False,
|
|
||||||
ignore_numbers: bool = False,
|
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
device: Optional[ComponentDevice] = None,
|
device: Optional[ComponentDevice] = None,
|
||||||
token: Secret = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
token: Secret = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||||
@ -40,42 +32,25 @@ class SASEvaluator:
|
|||||||
"""
|
"""
|
||||||
Creates a new instance of SASEvaluator.
|
Creates a new instance of SASEvaluator.
|
||||||
|
|
||||||
:param labels: The list of expected answers.
|
|
||||||
:param model: SentenceTransformers semantic textual similarity model, should be path or string pointing to
|
:param model: SentenceTransformers semantic textual similarity model, should be path or string pointing to
|
||||||
a downloadable model.
|
a downloadable model.
|
||||||
:param regexes_to_ignore: A list of regular expressions. If provided, it removes substrings
|
|
||||||
matching these regular expressions from both predictions and labels before comparison. Defaults to None.
|
|
||||||
:param ignore_case: If True, performs case-insensitive comparison. Defaults to False.
|
|
||||||
:param ignore_punctuation: If True, removes punctuation from both predictions and labels before
|
|
||||||
comparison. Defaults to False.
|
|
||||||
:param ignore_numbers: If True, removes numerical digits from both predictions and labels
|
|
||||||
before comparison. Defaults to False.
|
|
||||||
:param batch_size: Number of prediction-label pairs to encode at once.
|
:param batch_size: Number of prediction-label pairs to encode at once.
|
||||||
:param device: The device on which the model is loaded. If `None`, the default device is automatically
|
:param device: The device on which the model is loaded. If `None`, the default device is automatically
|
||||||
selected.
|
selected.
|
||||||
:param token: The Hugging Face token for HTTP bearer authorization.
|
:param token: The Hugging Face token for HTTP bearer authorization.
|
||||||
You can find your HF token at https://huggingface.co/settings/tokens.
|
You can find your HF token at https://huggingface.co/settings/tokens.
|
||||||
"""
|
"""
|
||||||
metrics_import.check()
|
sas_import.check()
|
||||||
|
|
||||||
self._labels = labels
|
|
||||||
self._model = model
|
self._model = model
|
||||||
self._regexes_to_ignore = regexes_to_ignore
|
|
||||||
self._ignore_case = ignore_case
|
|
||||||
self._ignore_punctuation = ignore_punctuation
|
|
||||||
self._ignore_numbers = ignore_numbers
|
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._device = device
|
self._device = device
|
||||||
self._token = token
|
self._token = token
|
||||||
|
self._similarity_model = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return default_to_dict(
|
return default_to_dict(
|
||||||
self,
|
self,
|
||||||
labels=self._labels,
|
|
||||||
regexes_to_ignore=self._regexes_to_ignore,
|
|
||||||
ignore_case=self._ignore_case,
|
|
||||||
ignore_punctuation=self._ignore_punctuation,
|
|
||||||
ignore_numbers=self._ignore_numbers,
|
|
||||||
model=self._model,
|
model=self._model,
|
||||||
batch_size=self._batch_size,
|
batch_size=self._batch_size,
|
||||||
device=self._device.to_dict() if self._device else None,
|
device=self._device.to_dict() if self._device else None,
|
||||||
@ -89,42 +64,54 @@ class SASEvaluator:
|
|||||||
data["init_parameters"]["device"] = ComponentDevice.from_dict(device)
|
data["init_parameters"]["device"] = ComponentDevice.from_dict(device)
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
def warm_up(self):
|
||||||
|
"""
|
||||||
|
Load the model used for evaluation
|
||||||
|
"""
|
||||||
|
token = self._token.resolve_value() if self._token else None
|
||||||
|
config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
|
||||||
|
cross_encoder_used = False
|
||||||
|
if config.architectures:
|
||||||
|
cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
|
||||||
|
device = ComponentDevice.resolve_device(self._device).to_torch_str()
|
||||||
|
# Based on the Model string we can load either Bi-Encoders or Cross Encoders.
|
||||||
|
# Similarity computation changes for both approaches
|
||||||
|
if cross_encoder_used:
|
||||||
|
self._similarity_model = CrossEncoder(
|
||||||
|
self._model,
|
||||||
|
device=device,
|
||||||
|
tokenizer_args={"use_auth_token": token},
|
||||||
|
automodel_args={"use_auth_token": token},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._similarity_model = SentenceTransformer(self._model, device=device, use_auth_token=token)
|
||||||
|
|
||||||
@component.output_types(sas=float, scores=List[float])
|
@component.output_types(sas=float, scores=List[float])
|
||||||
def run(self, predictions: List[str]) -> Dict[str, Any]:
|
def run(self, labels: List[str], predictions: List[str]) -> Dict[str, Any]:
|
||||||
if len(predictions) != len(self._labels):
|
"""
|
||||||
|
Run the SASEvaluator to compute the Semantic Answer Similarity (SAS) between a list of predictions and a list of
|
||||||
|
labels. Both must be list of strings of same length.
|
||||||
|
|
||||||
|
:param predictions: List of predictions.
|
||||||
|
:param labels: List of labels against which the predictions are compared.
|
||||||
|
:returns: A dictionary with the following outputs:
|
||||||
|
* `sas` - Cumulative SAS score for the entire dataset.
|
||||||
|
* `scores` - A list of similarity scores for each prediction-label pair.
|
||||||
|
"""
|
||||||
|
if len(labels) != len(predictions):
|
||||||
raise ValueError("The number of predictions and labels must be the same.")
|
raise ValueError("The number of predictions and labels must be the same.")
|
||||||
|
|
||||||
if len(predictions) == 0:
|
if len(predictions) == 0:
|
||||||
return {"sas": 0.0, "scores": [0.0]}
|
return {"sas": 0.0, "scores": [0.0]}
|
||||||
|
|
||||||
token = self._token.resolve_value() if self._token else None
|
if not self._similarity_model:
|
||||||
|
msg = "The model has not been initialized. Call warm_up() before running the evaluator."
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
predictions = _preprocess_text(
|
if isinstance(self._similarity_model, CrossEncoder):
|
||||||
predictions, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
|
|
||||||
)
|
|
||||||
labels = _preprocess_text(
|
|
||||||
self._labels, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
|
|
||||||
)
|
|
||||||
config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
|
|
||||||
cross_encoder_used = False
|
|
||||||
if config.architectures:
|
|
||||||
cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
|
|
||||||
|
|
||||||
device = ComponentDevice.resolve_device(self._device)
|
|
||||||
|
|
||||||
# Based on the Model string we can load either Bi-Encoders or Cross Encoders.
|
|
||||||
# Similarity computation changes for both approaches
|
|
||||||
|
|
||||||
if cross_encoder_used:
|
|
||||||
# For Cross Encoders we create a list of pairs of predictions and labels
|
# For Cross Encoders we create a list of pairs of predictions and labels
|
||||||
similarity_model = CrossEncoder(
|
|
||||||
self._model,
|
|
||||||
device=device.to_torch_str(),
|
|
||||||
tokenizer_args={"use_auth_token": token},
|
|
||||||
automodel_args={"use_auth_token": token},
|
|
||||||
)
|
|
||||||
sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)]
|
sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)]
|
||||||
similarity_scores = similarity_model.predict(
|
similarity_scores = self._similarity_model.predict(
|
||||||
sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
|
sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -138,11 +125,12 @@ class SASEvaluator:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# For Bi-encoders we create embeddings separately for predictions and labels
|
# For Bi-encoders we create embeddings separately for predictions and labels
|
||||||
similarity_model = SentenceTransformer(self._model, device=device.to_torch_str(), use_auth_token=token)
|
predictions_embeddings = self._similarity_model.encode(
|
||||||
predictions_embeddings = similarity_model.encode(
|
|
||||||
predictions, batch_size=self._batch_size, convert_to_tensor=True
|
predictions, batch_size=self._batch_size, convert_to_tensor=True
|
||||||
)
|
)
|
||||||
label_embeddings = similarity_model.encode(labels, batch_size=self._batch_size, convert_to_tensor=True)
|
label_embeddings = self._similarity_model.encode(
|
||||||
|
labels, batch_size=self._batch_size, convert_to_tensor=True
|
||||||
|
)
|
||||||
|
|
||||||
# Compute cosine-similarities
|
# Compute cosine-similarities
|
||||||
scores = util.cos_sim(predictions_embeddings, label_embeddings)
|
scores = util.cos_sim(predictions_embeddings, label_embeddings)
|
||||||
|
|||||||
@ -7,14 +7,8 @@ from haystack.utils.device import ComponentDevice
|
|||||||
class TestSASEvaluator:
|
class TestSASEvaluator:
|
||||||
def test_init_default(self, monkeypatch):
|
def test_init_default(self, monkeypatch):
|
||||||
monkeypatch.setenv("HF_API_TOKEN", "fake-token")
|
monkeypatch.setenv("HF_API_TOKEN", "fake-token")
|
||||||
labels = ["label1", "label2", "label3"]
|
evaluator = SASEvaluator()
|
||||||
evaluator = SASEvaluator(labels=labels)
|
|
||||||
|
|
||||||
assert evaluator._labels == labels
|
|
||||||
assert evaluator._regexes_to_ignore is None
|
|
||||||
assert evaluator._ignore_case is False
|
|
||||||
assert evaluator._ignore_punctuation is False
|
|
||||||
assert evaluator._ignore_numbers is False
|
|
||||||
assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||||
assert evaluator._batch_size == 32
|
assert evaluator._batch_size == 32
|
||||||
assert evaluator._device is None
|
assert evaluator._device is None
|
||||||
@ -23,17 +17,11 @@ class TestSASEvaluator:
|
|||||||
def test_to_dict(self, monkeypatch):
|
def test_to_dict(self, monkeypatch):
|
||||||
monkeypatch.setenv("HF_API_TOKEN", "fake-token")
|
monkeypatch.setenv("HF_API_TOKEN", "fake-token")
|
||||||
|
|
||||||
labels = ["label1", "label2", "label3"]
|
evaluator = SASEvaluator(device=ComponentDevice.from_str("cuda:0"))
|
||||||
evaluator = SASEvaluator(labels=labels, device=ComponentDevice.from_str("cuda:0"))
|
|
||||||
|
|
||||||
expected_dict = {
|
expected_dict = {
|
||||||
"type": "haystack.components.eval.sas_evaluator.SASEvaluator",
|
"type": "haystack.components.eval.sas_evaluator.SASEvaluator",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"labels": labels,
|
|
||||||
"regexes_to_ignore": None,
|
|
||||||
"ignore_case": False,
|
|
||||||
"ignore_punctuation": False,
|
|
||||||
"ignore_numbers": False,
|
|
||||||
"model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
"model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"device": {"type": "single", "device": "cuda:0"},
|
"device": {"type": "single", "device": "cuda:0"},
|
||||||
@ -48,11 +36,6 @@ class TestSASEvaluator:
|
|||||||
{
|
{
|
||||||
"type": "haystack.components.eval.sas_evaluator.SASEvaluator",
|
"type": "haystack.components.eval.sas_evaluator.SASEvaluator",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"labels": ["label1", "label2", "label3"],
|
|
||||||
"regexes_to_ignore": None,
|
|
||||||
"ignore_case": False,
|
|
||||||
"ignore_punctuation": False,
|
|
||||||
"ignore_numbers": False,
|
|
||||||
"model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
"model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"device": {"type": "single", "device": "cuda:0"},
|
"device": {"type": "single", "device": "cuda:0"},
|
||||||
@ -61,54 +44,62 @@ class TestSASEvaluator:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert evaluator._labels == ["label1", "label2", "label3"]
|
|
||||||
assert evaluator._regexes_to_ignore is None
|
|
||||||
assert evaluator._ignore_case is False
|
|
||||||
assert evaluator._ignore_punctuation is False
|
|
||||||
assert evaluator._ignore_numbers is False
|
|
||||||
assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||||
assert evaluator._batch_size == 32
|
assert evaluator._batch_size == 32
|
||||||
assert evaluator._device.to_torch_str() == "cuda:0"
|
assert evaluator._device.to_torch_str() == "cuda:0"
|
||||||
assert evaluator._token.resolve_value() == "fake-token"
|
assert evaluator._token.resolve_value() == "fake-token"
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_run_with_empty_inputs(self):
|
def test_run_with_empty_inputs(self):
|
||||||
evaluator = SASEvaluator(labels=[])
|
evaluator = SASEvaluator()
|
||||||
result = evaluator.run(predictions=[])
|
result = evaluator.run(labels=[], predictions=[])
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result["sas"] == 0.0
|
assert result["sas"] == 0.0
|
||||||
assert result["scores"] == [0.0]
|
assert result["scores"] == [0.0]
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_run_with_different_lengths(self):
|
def test_run_with_different_lengths(self):
|
||||||
|
evaluator = SASEvaluator()
|
||||||
labels = [
|
labels = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
]
|
]
|
||||||
evaluator = SASEvaluator(labels=labels)
|
|
||||||
|
|
||||||
predictions = [
|
predictions = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
]
|
]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
evaluator.run(predictions)
|
evaluator.run(labels=labels, predictions=predictions)
|
||||||
|
|
||||||
@pytest.mark.integration
|
def test_run_not_warmed_up(self):
|
||||||
def test_run_with_matching_predictions(self):
|
evaluator = SASEvaluator()
|
||||||
labels = [
|
labels = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
]
|
]
|
||||||
evaluator = SASEvaluator(labels=labels)
|
|
||||||
predictions = [
|
predictions = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
]
|
]
|
||||||
result = evaluator.run(predictions=predictions)
|
with pytest.raises(RuntimeError):
|
||||||
|
evaluator.run(labels=labels, predictions=predictions)
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_run_with_matching_predictions(self):
|
||||||
|
evaluator = SASEvaluator()
|
||||||
|
labels = [
|
||||||
|
"A construction budget of US $2.3 billion",
|
||||||
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
|
]
|
||||||
|
predictions = [
|
||||||
|
"A construction budget of US $2.3 billion",
|
||||||
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
|
]
|
||||||
|
evaluator.warm_up()
|
||||||
|
result = evaluator.run(labels=labels, predictions=predictions)
|
||||||
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result["sas"] == pytest.approx(1.0)
|
assert result["sas"] == pytest.approx(1.0)
|
||||||
@ -116,177 +107,68 @@ class TestSASEvaluator:
|
|||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_run_with_single_prediction(self):
|
def test_run_with_single_prediction(self):
|
||||||
labels = ["US $2.3 billion"]
|
evaluator = SASEvaluator()
|
||||||
evaluator = SASEvaluator(labels=labels)
|
|
||||||
|
|
||||||
result = evaluator.run(predictions=["A construction budget of US $2.3 billion"])
|
labels = ["US $2.3 billion"]
|
||||||
|
evaluator.warm_up()
|
||||||
|
result = evaluator.run(labels=labels, predictions=["A construction budget of US $2.3 billion"])
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result["sas"] == pytest.approx(0.689089, abs=1e-5)
|
assert result["sas"] == pytest.approx(0.689089, abs=1e-5)
|
||||||
assert result["scores"] == pytest.approx([0.689089], abs=1e-5)
|
assert result["scores"] == pytest.approx([0.689089], abs=1e-5)
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_run_with_mismatched_predictions(self):
|
def test_run_with_mismatched_predictions(self):
|
||||||
|
evaluator = SASEvaluator()
|
||||||
labels = [
|
labels = [
|
||||||
"US $2.3 billion",
|
"US $2.3 billion",
|
||||||
"Paris's cultural magnificence is symbolized by the Eiffel Tower",
|
"Paris's cultural magnificence is symbolized by the Eiffel Tower",
|
||||||
"Japan was transformed into a modernized world power after the Meiji Restoration.",
|
"Japan was transformed into a modernized world power after the Meiji Restoration.",
|
||||||
]
|
]
|
||||||
evaluator = SASEvaluator(labels=labels)
|
|
||||||
predictions = [
|
predictions = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
]
|
]
|
||||||
result = evaluator.run(predictions=predictions)
|
evaluator.warm_up()
|
||||||
|
result = evaluator.run(labels=labels, predictions=predictions)
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result["sas"] == pytest.approx(0.8227189)
|
assert result["sas"] == pytest.approx(0.8227189)
|
||||||
assert result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
|
assert result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_run_with_ignore_case(self):
|
|
||||||
labels = [
|
|
||||||
"A construction budget of US $2.3 BILLION",
|
|
||||||
"The EIFFEL TOWER, completed in 1889, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The MEIJI RESTORATION in 1868 transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
evaluator = SASEvaluator(labels=labels, ignore_case=True)
|
|
||||||
predictions = [
|
|
||||||
"A construction budget of US $2.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
result = evaluator.run(predictions=predictions)
|
|
||||||
assert len(result) == 2
|
|
||||||
assert result["sas"] == pytest.approx(1.0)
|
|
||||||
assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_run_with_ignore_punctuation(self):
|
|
||||||
labels = [
|
|
||||||
"A construction budget of US $2.3 billion",
|
|
||||||
"The Eiffel Tower completed in 1889 symbolizes Paris's cultural magnificence",
|
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power",
|
|
||||||
]
|
|
||||||
evaluator = SASEvaluator(labels=labels, ignore_punctuation=True)
|
|
||||||
predictions = [
|
|
||||||
"A construction budget of US $2.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
result = evaluator.run(predictions=predictions)
|
|
||||||
assert len(result) == 2
|
|
||||||
assert result["sas"] == pytest.approx(1.0)
|
|
||||||
assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_run_with_ignore_numbers(self):
|
|
||||||
labels = [
|
|
||||||
"A construction budget of US $10.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
evaluator = SASEvaluator(labels=labels, ignore_numbers=True)
|
|
||||||
predictions = [
|
|
||||||
"A construction budget of US $2.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
result = evaluator.run(predictions=predictions)
|
|
||||||
assert result["sas"] == pytest.approx(1.0)
|
|
||||||
assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_run_with_regex_to_ignore(self):
|
|
||||||
labels = [
|
|
||||||
"A construction budget of US $10.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
evaluator = SASEvaluator(labels=labels, regexes_to_ignore=[r"\d+"])
|
|
||||||
predictions = [
|
|
||||||
"A construction budget of US $2.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
result = evaluator.run(predictions=predictions)
|
|
||||||
assert len(result) == 2
|
|
||||||
assert result["sas"] == pytest.approx(1.0)
|
|
||||||
assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_run_with_multiple_regex_to_ignore(self):
|
|
||||||
labels = [
|
|
||||||
"A construction budget of US $10.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
evaluator = SASEvaluator(labels=labels, regexes_to_ignore=[r"\d+", r"[^\w\s]"])
|
|
||||||
predictions = [
|
|
||||||
"A construction budget of US $2.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
result = evaluator.run(predictions=predictions)
|
|
||||||
assert len(result) == 2
|
|
||||||
assert result["sas"] == pytest.approx(1.0)
|
|
||||||
assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_run_with_multiple_ignore_parameters(self):
|
|
||||||
labels = [
|
|
||||||
"A construction budget of US $10.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
evaluator = SASEvaluator(
|
|
||||||
labels=labels,
|
|
||||||
ignore_numbers=True,
|
|
||||||
ignore_punctuation=True,
|
|
||||||
ignore_case=True,
|
|
||||||
regexes_to_ignore=[r"[^\w\s\d]+"],
|
|
||||||
)
|
|
||||||
predictions = [
|
|
||||||
"A construction budget of US $2.3 billion",
|
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
|
||||||
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
|
|
||||||
]
|
|
||||||
result = evaluator.run(predictions=predictions)
|
|
||||||
assert len(result) == 2
|
|
||||||
assert result["sas"] == pytest.approx(1.0)
|
|
||||||
assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_run_with_bi_encoder_model(self):
|
def test_run_with_bi_encoder_model(self):
|
||||||
|
evaluator = SASEvaluator(model="sentence-transformers/all-mpnet-base-v2")
|
||||||
labels = [
|
labels = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
]
|
]
|
||||||
evaluator = SASEvaluator(labels=labels, model="sentence-transformers/all-mpnet-base-v2")
|
|
||||||
predictions = [
|
predictions = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
]
|
]
|
||||||
result = evaluator.run(predictions=predictions)
|
evaluator.warm_up()
|
||||||
|
result = evaluator.run(labels=labels, predictions=predictions)
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result["sas"] == pytest.approx(1.0)
|
assert result["sas"] == pytest.approx(1.0)
|
||||||
assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
|
assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_run_with_cross_encoder_model(self):
|
def test_run_with_cross_encoder_model(self):
|
||||||
|
evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||||
labels = [
|
labels = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
]
|
]
|
||||||
evaluator = SASEvaluator(labels=labels, model="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
|
||||||
predictions = [
|
predictions = [
|
||||||
"A construction budget of US $2.3 billion",
|
"A construction budget of US $2.3 billion",
|
||||||
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
|
||||||
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
|
||||||
]
|
]
|
||||||
result = evaluator.run(predictions=predictions)
|
evaluator.warm_up()
|
||||||
|
result = evaluator.run(labels=labels, predictions=predictions)
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result["sas"] == pytest.approx(0.999967, abs=1e-5)
|
assert result["sas"] == pytest.approx(0.999967, abs=1e-5)
|
||||||
assert result["scores"] == pytest.approx([0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5)
|
assert result["scores"] == pytest.approx([0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user