feat: Add scaling and thresholding of the similarity ranker scores (#6683)

* Add scale_score functionality to the TransformersSimilarityRanker

* Updated test to check scores

* Use pytest approx when comparing floats

* Updated how scale score works and added calibration factor. Started to add score threshold.

* Add support for score_threshold

* Add some parameters to the run method

* Add release notes

* Fix mypy

* Be more tolerant on the score values

* Adding unit test for scale_score=False

* Add unit test for score threshold

* Update tests

* Rename test

* Fix typo

* PR comments
This commit is contained in:
Sebastian Husch Lee 2024-01-08 09:05:24 +01:00 committed by GitHub
parent 552f0e394b
commit beade1cef9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 130 additions and 16 deletions

View File

@ -43,6 +43,9 @@ class TransformersSimilarityRanker:
top_k: int = 10,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
scale_score: bool = True,
calibration_factor: Optional[float] = 1.0,
score_threshold: Optional[float] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
):
"""
@ -57,6 +60,11 @@ class TransformersSimilarityRanker:
:param top_k: The maximum number of Documents to return per query.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
:param scale_score: Whether the raw logit predictions will be scaled using a Sigmoid activation function.
Set this to False if you do not want any scaling of the raw logit predictions.
:param calibration_factor: Factor used for calibrating probabilities calculated by
`sigmoid(logits * calibration_factor)`. This is only used if `scale_score` is set to True.
:param score_threshold: If provided only returns documents with a score above this threshold.
:param model_kwargs: Additional keyword arguments passed to `AutoModelForSequenceClassification.from_pretrained`
when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass,
see the model's documentation.
@ -73,6 +81,13 @@ class TransformersSimilarityRanker:
self.tokenizer = None
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.scale_score = scale_score
self.calibration_factor = calibration_factor
if self.scale_score and self.calibration_factor is None:
raise ValueError(
f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
)
self.score_threshold = score_threshold
self.model_kwargs = model_kwargs or {}
def _get_telemetry_data(self) -> Dict[str, Any]:
@ -106,28 +121,53 @@ class TransformersSimilarityRanker:
top_k=self.top_k,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
scale_score=self.scale_score,
calibration_factor=self.calibration_factor,
score_threshold=self.score_threshold,
model_kwargs=self.model_kwargs,
)
@component.output_types(documents=List[Document])
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
def run(
self,
query: str,
documents: List[Document],
top_k: Optional[int] = None,
scale_score: Optional[bool] = None,
calibration_factor: Optional[float] = None,
score_threshold: Optional[float] = None,
):
"""
Returns a list of Documents ranked by their similarity to the given query.
:param query: Query string.
:param documents: List of Documents.
:param top_k: The maximum number of Documents you want the Ranker to return.
:param scale_score: Whether the raw logit predictions will be scaled using a Sigmoid activation function.
Set this to False if you do not want any scaling of the raw logit predictions.
:param calibration_factor: Factor used for calibrating probabilities calculated by
`sigmoid(logits * calibration_factor)`. This is only used if `scale_score` is set to True.
:param score_threshold: If provided only returns documents with a score above this threshold.
:return: List of Documents sorted by their similarity to the query with the most similar Documents appearing first.
"""
if not documents:
return {"documents": []}
if top_k is None:
top_k = self.top_k
elif top_k <= 0:
top_k = top_k or self.top_k
if top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
scale_score = scale_score or self.scale_score
calibration_factor = calibration_factor or self.calibration_factor
if scale_score and calibration_factor is None:
raise ValueError(
f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
)
if score_threshold is None:
score_threshold = self.score_threshold
# If a model path is provided but the model isn't loaded
if self.model_name_or_path and not self.model:
raise ComponentError(
@ -150,10 +190,20 @@ class TransformersSimilarityRanker:
with torch.inference_mode():
similarity_scores = self.model(**features).logits.squeeze(dim=1) # type: ignore
if scale_score:
similarity_scores = torch.sigmoid(similarity_scores * calibration_factor)
_, sorted_indices = torch.sort(similarity_scores, descending=True)
sorted_indices = sorted_indices.cpu().tolist() # type: ignore
similarity_scores = similarity_scores.cpu().tolist()
ranked_docs = []
for sorted_index_tensor in sorted_indices:
i = sorted_index_tensor.item()
documents[i].score = similarity_scores[i].item()
for sorted_index in sorted_indices:
i = sorted_index
documents[i].score = similarity_scores[i]
ranked_docs.append(documents[i])
if score_threshold is not None:
ranked_docs = [doc for doc in ranked_docs if doc.score >= score_threshold]
return {"documents": ranked_docs[:top_k]}

View File

@ -0,0 +1,6 @@
---
enhancements:
- |
Adds scale_score, which allows users to toggle if they would like their document scores to be raw logits or scaled between 0 and 1 (using the sigmoid function). This is a feature that already existed in Haystack v1 that is being moved over.
Adds calibration_factor. This follows the example from the ExtractiveReader which allows the user to better control the spread of scores when scaling the score using sigmoid.
Adds score_threshold. Also copied from the ExtractiveReader. This optionally allows users to set a score threshold where only documents with a score above this threshold are returned.

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from haystack import Document, ComponentError
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
@ -19,6 +20,9 @@ class TestSimilarityRanker:
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"meta_fields_to_embed": [],
"embedding_separator": "\n",
"scale_score": True,
"calibration_factor": 1.0,
"score_threshold": None,
"model_kwargs": {},
},
}
@ -29,6 +33,9 @@ class TestSimilarityRanker:
device="cuda",
token="my_token",
top_k=5,
scale_score=False,
calibration_factor=None,
score_threshold=0.01,
model_kwargs={"torch_dtype": "auto"},
)
data = component.to_dict()
@ -41,13 +48,18 @@ class TestSimilarityRanker:
"top_k": 5,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
"scale_score": False,
"calibration_factor": None,
"score_threshold": 0.01,
"model_kwargs": {"torch_dtype": "auto"},
},
}
@patch("torch.sigmoid")
@patch("torch.sort")
def test_embed_meta(self, mocked_sort):
def test_embed_meta(self, mocked_sort, mocked_sigmoid):
mocked_sort.return_value = (None, torch.tensor([0]))
mocked_sigmoid.return_value = torch.tensor([0])
embedder = TransformersSimilarityRanker(
model_name_or_path="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
)
@ -71,16 +83,60 @@ class TestSimilarityRanker:
return_tensors="pt",
)
@patch("torch.sort")
def test_scale_score_false(self, mocked_sort):
mocked_sort.return_value = (None, torch.tensor([0, 1]))
embedder = TransformersSimilarityRanker(model_name_or_path="model", scale_score=False)
embedder.model = MagicMock()
embedder.model.return_value = SequenceClassifierOutput(
loss=None, logits=torch.FloatTensor([[-10.6859], [-8.9874]]), hidden_states=None, attentions=None
)
embedder.tokenizer = MagicMock()
documents = [Document(content="document number 0"), Document(content="document number 1")]
out = embedder.run(query="test", documents=documents)
assert out["documents"][0].score == pytest.approx(-10.6859, abs=1e-4)
assert out["documents"][1].score == pytest.approx(-8.9874, abs=1e-4)
@patch("torch.sort")
def test_score_threshold(self, mocked_sort):
mocked_sort.return_value = (None, torch.tensor([0, 1]))
embedder = TransformersSimilarityRanker(model_name_or_path="model", scale_score=False, score_threshold=0.1)
embedder.model = MagicMock()
embedder.model.return_value = SequenceClassifierOutput(
loss=None, logits=torch.FloatTensor([[0.955], [0.001]]), hidden_states=None, attentions=None
)
embedder.tokenizer = MagicMock()
documents = [Document(content="document number 0"), Document(content="document number 1")]
out = embedder.run(query="test", documents=documents)
assert len(out["documents"]) == 1
@pytest.mark.integration
@pytest.mark.parametrize(
"query,docs_before_texts,expected_first_text",
"query,docs_before_texts,expected_first_text,scores",
[
("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"),
("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"),
("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"),
(
"City in Bosnia and Herzegovina",
["Berlin", "Belgrade", "Sarajevo"],
"Sarajevo",
[2.2864143829792738e-05, 0.00012495707778725773, 0.009869757108390331],
),
(
"Machine learning",
["Python", "Bakery in Paris", "Tesla Giga Berlin"],
"Python",
[1.9063229046878405e-05, 1.434577916370472e-05, 1.3049247172602918e-05],
),
(
"Cubist movement",
["Nirvana", "Pablo Picasso", "Coffee"],
"Pablo Picasso",
[1.3313065210240893e-05, 9.90335684036836e-05, 1.3518535524781328e-05],
),
],
)
def test_run(self, query, docs_before_texts, expected_first_text):
def test_run(self, query, docs_before_texts, expected_first_text, scores):
"""
Test if the component ranks documents correctly.
"""
@ -93,8 +149,10 @@ class TestSimilarityRanker:
assert len(docs_after) == 3
assert docs_after[0].content == expected_first_text
sorted_scores = sorted([doc.score for doc in docs_after], reverse=True)
assert [doc.score for doc in docs_after] == sorted_scores
sorted_scores = sorted(scores, reverse=True)
assert docs_after[0].score == pytest.approx(sorted_scores[0], abs=1e-6)
assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)
# Returns an empty list if no documents are provided
@pytest.mark.integration