refactor: Refactor SASEvaluator (#6998)
* Remove preprocessing from SASEvaluator and add warm_up method
* Update docstrings
This commit is contained in:
parent 06a9349095
commit 2a4e6a1de2
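The net effect on callers: labels move from the constructor into run(), the text-preprocessing options disappear, and the model must now be loaded explicitly before scoring. A minimal sketch of the new call pattern (the import path is inferred from the component's type string in the tests below):

from haystack.components.eval.sas_evaluator import SASEvaluator

evaluator = SASEvaluator()  # labels and preprocessing flags are no longer constructor arguments
evaluator.warm_up()         # loads the Bi- or Cross-Encoder; run() raises RuntimeError without this

result = evaluator.run(
    labels=["US $2.3 billion"],
    predictions=["A construction budget of US $2.3 billion"],
)
print(result["sas"], result["scores"])  # cumulative score plus one score per pair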
@@ -7,9 +7,7 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import ComponentDevice, expit
 from haystack.utils.auth import Secret, deserialize_secrets_inplace
 
-from .preprocess import _preprocess_text
-
-with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as metrics_import:
+with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as sas_import:
     from sentence_transformers import CrossEncoder, SentenceTransformer, util
     from transformers import AutoConfig
 
@@ -22,17 +20,11 @@ class SASEvaluator:
 
    The SAS is computed using a pre-trained model from the Hugging Face model hub. The model can be either a
    Bi-Encoder or a Cross-Encoder. The choice of the model is based on the `model` parameter.
    The default model is `sentence-transformers/paraphrase-multilingual-mpnet-base-v2`.
    """
 
    def __init__(
        self,
-        labels: List[str],
        model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
-        regexes_to_ignore: Optional[List[str]] = None,
-        ignore_case: bool = False,
-        ignore_punctuation: bool = False,
-        ignore_numbers: bool = False,
        batch_size: int = 32,
        device: Optional[ComponentDevice] = None,
        token: Secret = Secret.from_env_var("HF_API_TOKEN", strict=False),
@@ -40,42 +32,25 @@ class SASEvaluator:
        """
        Creates a new instance of SASEvaluator.
 
-        :param labels: The list of expected answers.
        :param model: SentenceTransformers semantic textual similarity model, should be path or string pointing to
            a downloadable model.
-        :param regexes_to_ignore: A list of regular expressions. If provided, it removes substrings
-            matching these regular expressions from both predictions and labels before comparison. Defaults to None.
-        :param ignore_case: If True, performs case-insensitive comparison. Defaults to False.
-        :param ignore_punctuation: If True, removes punctuation from both predictions and labels before
-            comparison. Defaults to False.
-        :param ignore_numbers: If True, removes numerical digits from both predictions and labels
-            before comparison. Defaults to False.
        :param batch_size: Number of prediction-label pairs to encode at once.
        :param device: The device on which the model is loaded. If `None`, the default device is automatically
            selected.
        :param token: The Hugging Face token for HTTP bearer authorization.
            You can find your HF token at https://huggingface.co/settings/tokens.
        """
-        metrics_import.check()
+        sas_import.check()
 
-        self._labels = labels
        self._model = model
-        self._regexes_to_ignore = regexes_to_ignore
-        self._ignore_case = ignore_case
-        self._ignore_punctuation = ignore_punctuation
-        self._ignore_numbers = ignore_numbers
        self._batch_size = batch_size
        self._device = device
        self._token = token
+        self._similarity_model = None
 
    def to_dict(self) -> Dict[str, Any]:
        return default_to_dict(
            self,
-            labels=self._labels,
-            regexes_to_ignore=self._regexes_to_ignore,
-            ignore_case=self._ignore_case,
-            ignore_punctuation=self._ignore_punctuation,
-            ignore_numbers=self._ignore_numbers,
            model=self._model,
            batch_size=self._batch_size,
            device=self._device.to_dict() if self._device else None,
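With labels and the preprocessing flags gone from __init__, the serialized init parameters shrink to model, batch_size, device, and token. A hedged round-trip sketch, mirroring the expected dictionary in the tests further down:

from haystack.utils import ComponentDevice
from haystack.components.eval.sas_evaluator import SASEvaluator

evaluator = SASEvaluator(device=ComponentDevice.from_str("cuda:0"))
data = evaluator.to_dict()
# data["init_parameters"] now holds only model, batch_size, device, and token;
# the device is serialized as {"type": "single", "device": "cuda:0"}
restored = SASEvaluator.from_dict(data)  # rebuilds the ComponentDevice from its dict form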
@@ -89,42 +64,54 @@ class SASEvaluator:
            data["init_parameters"]["device"] = ComponentDevice.from_dict(device)
        return default_from_dict(cls, data)
 
+    def warm_up(self):
+        """
+        Load the model used for evaluation
+        """
+        token = self._token.resolve_value() if self._token else None
+        config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
+        cross_encoder_used = False
+        if config.architectures:
+            cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
+        device = ComponentDevice.resolve_device(self._device).to_torch_str()
+        # Based on the Model string we can load either Bi-Encoders or Cross Encoders.
+        # Similarity computation changes for both approaches
+        if cross_encoder_used:
+            self._similarity_model = CrossEncoder(
+                self._model,
+                device=device,
+                tokenizer_args={"use_auth_token": token},
+                automodel_args={"use_auth_token": token},
+            )
+        else:
+            self._similarity_model = SentenceTransformer(self._model, device=device, use_auth_token=token)
+
    @component.output_types(sas=float, scores=List[float])
-    def run(self, predictions: List[str]) -> Dict[str, Any]:
-        if len(predictions) != len(self._labels):
+    def run(self, labels: List[str], predictions: List[str]) -> Dict[str, Any]:
+        """
+        Run the SASEvaluator to compute the Semantic Answer Similarity (SAS) between a list of predictions and a list
+        of labels. Both must be lists of strings of the same length.
+
+        :param predictions: List of predictions.
+        :param labels: List of labels against which the predictions are compared.
+        :returns: A dictionary with the following outputs:
+            * `sas` - Cumulative SAS score for the entire dataset.
+            * `scores` - A list of similarity scores for each prediction-label pair.
+        """
+        if len(labels) != len(predictions):
            raise ValueError("The number of predictions and labels must be the same.")
 
        if len(predictions) == 0:
            return {"sas": 0.0, "scores": [0.0]}
 
-        token = self._token.resolve_value() if self._token else None
+        if not self._similarity_model:
+            msg = "The model has not been initialized. Call warm_up() before running the evaluator."
+            raise RuntimeError(msg)
 
-        predictions = _preprocess_text(
-            predictions, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
-        )
-        labels = _preprocess_text(
-            self._labels, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
-        )
-        config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
-        cross_encoder_used = False
-        if config.architectures:
-            cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
-
-        device = ComponentDevice.resolve_device(self._device)
-
-        # Based on the Model string we can load either Bi-Encoders or Cross Encoders.
-        # Similarity computation changes for both approaches
-
-        if cross_encoder_used:
+        if isinstance(self._similarity_model, CrossEncoder):
            # For Cross Encoders we create a list of pairs of predictions and labels
-            similarity_model = CrossEncoder(
-                self._model,
-                device=device.to_torch_str(),
-                tokenizer_args={"use_auth_token": token},
-                automodel_args={"use_auth_token": token},
-            )
            sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)]
-            similarity_scores = similarity_model.predict(
+            similarity_scores = self._similarity_model.predict(
                sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
            )
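The new warm_up() picks the encoder family purely from the model's Hugging Face config: any architecture name ending in ForSequenceClassification is loaded as a CrossEncoder, everything else as a SentenceTransformer. The same check, pulled out as a standalone sketch (the helper name is ours, not part of the commit):

from typing import Optional

from transformers import AutoConfig

def uses_cross_encoder(model_id: str, token: Optional[str] = None) -> bool:
    # Cross-Encoder checkpoints expose a *ForSequenceClassification head in their
    # config; Bi-Encoder checkpoints do not, so they fall through to SentenceTransformer.
    config = AutoConfig.from_pretrained(model_id, use_auth_token=token)
    if not config.architectures:
        return False
    return any(arch.endswith("ForSequenceClassification") for arch in config.architectures)

uses_cross_encoder("cross-encoder/ms-marco-MiniLM-L-6-v2")     # True  -> CrossEncoder
uses_cross_encoder("sentence-transformers/all-mpnet-base-v2")  # False -> SentenceTransformer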
@@ -138,11 +125,12 @@ class SASEvaluator:
 
        else:
            # For Bi-encoders we create embeddings separately for predictions and labels
-            similarity_model = SentenceTransformer(self._model, device=device.to_torch_str(), use_auth_token=token)
-            predictions_embeddings = similarity_model.encode(
+            predictions_embeddings = self._similarity_model.encode(
                predictions, batch_size=self._batch_size, convert_to_tensor=True
            )
-            label_embeddings = similarity_model.encode(labels, batch_size=self._batch_size, convert_to_tensor=True)
+            label_embeddings = self._similarity_model.encode(
+                labels, batch_size=self._batch_size, convert_to_tensor=True
+            )
 
            # Compute cosine-similarities
            scores = util.cos_sim(predictions_embeddings, label_embeddings)
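On the Bi-Encoder path, util.cos_sim yields the full predictions-by-labels similarity matrix, while SAS scores each prediction only against its own label, i.e. the diagonal. A sketch of that scoring step; the diagonal extraction and the mean as the "cumulative" aggregation are our assumptions about the code following this hunk:

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
predictions = ["A construction budget of US $2.3 billion"]
labels = ["US $2.3 billion"]

pred_emb = model.encode(predictions, batch_size=32, convert_to_tensor=True)
label_emb = model.encode(labels, batch_size=32, convert_to_tensor=True)

matrix = util.cos_sim(pred_emb, label_emb)  # shape: (len(predictions), len(labels))
scores = matrix.diagonal().tolist()         # pair i is scored by entry [i][i]
sas = sum(scores) / len(scores)             # assumed aggregation: mean of the per-pair scores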
@@ -7,14 +7,8 @@ from haystack.utils.device import ComponentDevice
 class TestSASEvaluator:
    def test_init_default(self, monkeypatch):
        monkeypatch.setenv("HF_API_TOKEN", "fake-token")
-        labels = ["label1", "label2", "label3"]
-        evaluator = SASEvaluator(labels=labels)
+        evaluator = SASEvaluator()
 
-        assert evaluator._labels == labels
-        assert evaluator._regexes_to_ignore is None
-        assert evaluator._ignore_case is False
-        assert evaluator._ignore_punctuation is False
-        assert evaluator._ignore_numbers is False
        assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        assert evaluator._batch_size == 32
        assert evaluator._device is None
@@ -23,17 +17,11 @@ class TestSASEvaluator:
    def test_to_dict(self, monkeypatch):
        monkeypatch.setenv("HF_API_TOKEN", "fake-token")
 
-        labels = ["label1", "label2", "label3"]
-        evaluator = SASEvaluator(labels=labels, device=ComponentDevice.from_str("cuda:0"))
+        evaluator = SASEvaluator(device=ComponentDevice.from_str("cuda:0"))
 
        expected_dict = {
            "type": "haystack.components.eval.sas_evaluator.SASEvaluator",
            "init_parameters": {
-                "labels": labels,
-                "regexes_to_ignore": None,
-                "ignore_case": False,
-                "ignore_punctuation": False,
-                "ignore_numbers": False,
                "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                "batch_size": 32,
                "device": {"type": "single", "device": "cuda:0"},
@@ -48,11 +36,6 @@ class TestSASEvaluator:
            {
                "type": "haystack.components.eval.sas_evaluator.SASEvaluator",
                "init_parameters": {
-                    "labels": ["label1", "label2", "label3"],
-                    "regexes_to_ignore": None,
-                    "ignore_case": False,
-                    "ignore_punctuation": False,
-                    "ignore_numbers": False,
                    "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                    "batch_size": 32,
                    "device": {"type": "single", "device": "cuda:0"},
@@ -61,54 +44,62 @@ class TestSASEvaluator:
            }
        )
 
-        assert evaluator._labels == ["label1", "label2", "label3"]
-        assert evaluator._regexes_to_ignore is None
-        assert evaluator._ignore_case is False
-        assert evaluator._ignore_punctuation is False
-        assert evaluator._ignore_numbers is False
        assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        assert evaluator._batch_size == 32
        assert evaluator._device.to_torch_str() == "cuda:0"
        assert evaluator._token.resolve_value() == "fake-token"
 
    @pytest.mark.integration
    def test_run_with_empty_inputs(self):
-        evaluator = SASEvaluator(labels=[])
-        result = evaluator.run(predictions=[])
+        evaluator = SASEvaluator()
+        result = evaluator.run(labels=[], predictions=[])
        assert len(result) == 2
        assert result["sas"] == 0.0
        assert result["scores"] == [0.0]
 
    @pytest.mark.integration
    def test_run_with_different_lengths(self):
+        evaluator = SASEvaluator()
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
        ]
-        evaluator = SASEvaluator(labels=labels)
 
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        with pytest.raises(ValueError):
-            evaluator.run(predictions)
+            evaluator.run(labels=labels, predictions=predictions)
 
    @pytest.mark.integration
-    def test_run_with_matching_predictions(self):
+    def test_run_not_warmed_up(self):
+        evaluator = SASEvaluator()
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
-        evaluator = SASEvaluator(labels=labels)
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
-        result = evaluator.run(predictions=predictions)
+        with pytest.raises(RuntimeError):
+            evaluator.run(labels=labels, predictions=predictions)
+
+    @pytest.mark.integration
+    def test_run_with_matching_predictions(self):
+        evaluator = SASEvaluator()
+        labels = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=predictions)
 
        assert len(result) == 2
        assert result["sas"] == pytest.approx(1.0)
@@ -116,177 +107,68 @@ class TestSASEvaluator:
 
    @pytest.mark.integration
    def test_run_with_single_prediction(self):
-        labels = ["US $2.3 billion"]
-        evaluator = SASEvaluator(labels=labels)
+        evaluator = SASEvaluator()
 
-        result = evaluator.run(predictions=["A construction budget of US $2.3 billion"])
+        labels = ["US $2.3 billion"]
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=["A construction budget of US $2.3 billion"])
        assert len(result) == 2
        assert result["sas"] == pytest.approx(0.689089, abs=1e-5)
        assert result["scores"] == pytest.approx([0.689089], abs=1e-5)
 
    @pytest.mark.integration
    def test_run_with_mismatched_predictions(self):
+        evaluator = SASEvaluator()
        labels = [
            "US $2.3 billion",
            "Paris's cultural magnificence is symbolized by the Eiffel Tower",
            "Japan was transformed into a modernized world power after the Meiji Restoration.",
        ]
-        evaluator = SASEvaluator(labels=labels)
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
-        result = evaluator.run(predictions=predictions)
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=predictions)
        assert len(result) == 2
        assert result["sas"] == pytest.approx(0.8227189)
        assert result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
 
    @pytest.mark.integration
-    def test_run_with_ignore_case(self):
-        labels = [
-            "A construction budget of US $2.3 BILLION",
-            "The EIFFEL TOWER, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The MEIJI RESTORATION in 1868 transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(labels=labels, ignore_case=True)
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-
-    @pytest.mark.integration
-    def test_run_with_ignore_punctuation(self):
-        labels = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower completed in 1889 symbolizes Paris's cultural magnificence",
-            "The Meiji Restoration in 1868 transformed Japan into a modernized world power",
-        ]
-        evaluator = SASEvaluator(labels=labels, ignore_punctuation=True)
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-
-    @pytest.mark.integration
-    def test_run_with_ignore_numbers(self):
-        labels = [
-            "A construction budget of US $10.3 billion",
-            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(labels=labels, ignore_numbers=True)
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-
-    @pytest.mark.integration
-    def test_run_with_regex_to_ignore(self):
-        labels = [
-            "A construction budget of US $10.3 billion",
-            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(labels=labels, regexes_to_ignore=[r"\d+"])
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-
-    @pytest.mark.integration
-    def test_run_with_multiple_regex_to_ignore(self):
-        labels = [
-            "A construction budget of US $10.3 billion",
-            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(labels=labels, regexes_to_ignore=[r"\d+", r"[^\w\s]"])
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-
-    @pytest.mark.integration
-    def test_run_with_multiple_ignore_parameters(self):
-        labels = [
-            "A construction budget of US $10.3 billion",
-            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
-        ]
-        evaluator = SASEvaluator(
-            labels=labels,
-            ignore_numbers=True,
-            ignore_punctuation=True,
-            ignore_case=True,
-            regexes_to_ignore=[r"[^\w\s\d]+"],
-        )
-        predictions = [
-            "A construction budget of US $2.3 billion",
-            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
-            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
-        ]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 2
-        assert result["sas"] == pytest.approx(1.0)
-        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
-
-    @pytest.mark.integration
    def test_run_with_bi_encoder_model(self):
+        evaluator = SASEvaluator(model="sentence-transformers/all-mpnet-base-v2")
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
-        evaluator = SASEvaluator(labels=labels, model="sentence-transformers/all-mpnet-base-v2")
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
-        result = evaluator.run(predictions=predictions)
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=predictions)
        assert len(result) == 2
        assert result["sas"] == pytest.approx(1.0)
        assert result["scores"] == pytest.approx([1.0, 1.0, 1.0])
 
    @pytest.mark.integration
    def test_run_with_cross_encoder_model(self):
+        evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
-        evaluator = SASEvaluator(labels=labels, model="cross-encoder/ms-marco-MiniLM-L-6-v2")
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
-        result = evaluator.run(predictions=predictions)
+        evaluator.warm_up()
+        result = evaluator.run(labels=labels, predictions=predictions)
        assert len(result) == 2
        assert result["sas"] == pytest.approx(0.999967, abs=1e-5)
        assert result["scores"] == pytest.approx([0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5)