mirror of https://github.com/deepset-ai/haystack.git
feat: TransformersSimilarityRanker: add batching across Documents during inference (#8344)
* First pass at adding batch support to TransformersSimilarityRanker
* Add test
* Add reno
commit 7227bcf9df (parent 675cf43be7)
@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
 with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
     import accelerate  # pylint: disable=unused-import # the library is used but not directly referenced
     import torch
+    from torch.utils.data import DataLoader, Dataset
     from transformers import AutoModelForSequenceClassification, AutoTokenizer


@@ -42,7 +43,7 @@ class TransformersSimilarityRanker:
     ```
     """

-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
         device: Optional[ComponentDevice] = None,
@@ -57,6 +58,7 @@ class TransformersSimilarityRanker:
         score_threshold: Optional[float] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        batch_size: int = 16,
     ):
         """
         Creates an instance of TransformersSimilarityRanker.
@@ -93,6 +95,9 @@ class TransformersSimilarityRanker:
         :param tokenizer_kwargs:
             Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
             Refer to specific model documentation for available kwargs.
+        :param batch_size:
+            The batch size to use for inference. The higher the batch size, the more memory is required.
+            If you run into memory issues, reduce the batch size.

         :raises ValueError:
             If `top_k` is not > 0.
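For orientation, here is a minimal usage sketch of the new parameter. It mirrors the integration test added further down (the model is the component's default cross-encoder; the query and document strings are illustrative):

```python
from haystack import Document
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker

# Smaller batches lower peak memory at the cost of throughput.
ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2", batch_size=8)
ranker.warm_up()  # loads the tokenizer and model

docs = [Document(content=text) for text in ["Berlin", "Belgrade", "Sarajevo"]]
result = ranker.run(query="City in Bosnia and Herzegovina", documents=docs)
print([(doc.content, doc.score) for doc in result["documents"]])  # best match first
```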
@@ -117,6 +122,7 @@ class TransformersSimilarityRanker:
         model_kwargs = resolve_hf_device_map(device=device, model_kwargs=model_kwargs)
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs or {}
+        self.batch_size = batch_size

         # Parameter validation
         if self.scale_score and self.calibration_factor is None:
@@ -261,11 +267,28 @@ class TransformersSimilarityRanker:
             text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
             query_doc_pairs.append([self.query_prefix + query, self.document_prefix + text_to_embed])

-        features = self.tokenizer(query_doc_pairs, padding=True, truncation=True, return_tensors="pt").to(  # type: ignore
+        class _Dataset(Dataset):
+            def __init__(self, batch_encoding):
+                self.batch_encoding = batch_encoding
+
+            def __len__(self):
+                return len(self.batch_encoding["input_ids"])
+
+            def __getitem__(self, item):
+                return {key: self.batch_encoding.data[key][item] for key in self.batch_encoding.data.keys()}
+
+        batch_enc = self.tokenizer(query_doc_pairs, padding=True, truncation=True, return_tensors="pt").to(  # type: ignore
             self.device.first_device.to_torch()
         )
+        dataset = _Dataset(batch_enc)
+        inp_dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+
+        similarity_scores = []
         with torch.inference_mode():
-            similarity_scores = self.model(**features).logits.squeeze(dim=1)  # type: ignore
+            for features in inp_dataloader:
+                model_preds = self.model(**features).logits.squeeze(dim=1)  # type: ignore
+                similarity_scores.extend(model_preds)
+        similarity_scores = torch.stack(similarity_scores)

         if scale_score:
             similarity_scores = torch.sigmoid(similarity_scores * calibration_factor)
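The pattern in this hunk — wrapping the tokenizer's `BatchEncoding` in a `torch.utils.data.Dataset` so a `DataLoader` can slice it into fixed-size mini-batches — also works outside the component. A self-contained sketch (the `EncodingDataset` name and the query/document pairs are made up for illustration; the model is the ranker's default):

```python
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class EncodingDataset(Dataset):
    """Index into a tokenizer BatchEncoding one row at a time."""

    def __init__(self, batch_encoding):
        self.batch_encoding = batch_encoding

    def __len__(self):
        return len(self.batch_encoding["input_ids"])

    def __getitem__(self, item):
        # Each item is a dict of 1-D tensors (input_ids, attention_mask, ...);
        # the DataLoader's default collate_fn stacks them back into batched tensors.
        return {key: self.batch_encoding[key][item] for key in self.batch_encoding.keys()}


model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # disable dropout for deterministic inference

pairs = [["query text", f"document {i}"] for i in range(5)]
encoding = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
loader = DataLoader(EncodingDataset(encoding), batch_size=2, shuffle=False)

scores = []
with torch.inference_mode():
    for features in loader:
        # One forward pass per mini-batch instead of one over all pairs at once.
        scores.extend(model(**features).logits.squeeze(dim=1))
scores = torch.stack(scores)
```

Because each forward pass sees at most `batch_size` pairs, peak activation memory is bounded by the batch size rather than by the total number of documents; note that the full tokenized encoding is still materialized up front, so only the model's forward-pass memory is batched.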
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    We added batching during inference time to the TransformersSimilarityRanker to help prevent out-of-memory (OOM) errors when ranking large numbers of Documents.
@@ -8,7 +8,7 @@ import pytest
 import torch
 from transformers.modeling_outputs import SequenceClassifierOutput

-from haystack import ComponentError, Document
+from haystack import Document
 from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice, DeviceMap
@@ -202,7 +202,9 @@ class TestSimilarityRanker:

     @patch("torch.sigmoid")
     @patch("torch.sort")
-    def test_embed_meta(self, mocked_sort, mocked_sigmoid):
+    @patch("torch.stack")
+    def test_embed_meta(self, mocked_stack, mocked_sort, mocked_sigmoid):
+        mocked_stack.return_value = torch.tensor([0])
         mocked_sort.return_value = (None, torch.tensor([0]))
         mocked_sigmoid.return_value = torch.tensor([0])
         embedder = TransformersSimilarityRanker(
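On the mock ordering in these test changes: stacked `unittest.mock.patch` decorators are applied bottom-up, so the decorator nearest the function supplies the first mock argument — which is why `mocked_stack` lands first in the new signatures. A tiny illustration (the function name is hypothetical; torch must be importable for the patch targets to resolve):

```python
from unittest.mock import patch

@patch("torch.sigmoid")  # outermost decorator -> last mock argument
@patch("torch.sort")
@patch("torch.stack")    # innermost decorator -> first mock argument
def check_ordering(mocked_stack, mocked_sort, mocked_sigmoid):
    # Each mock replaces the corresponding torch function for the call's duration.
    print(mocked_stack, mocked_sort, mocked_sigmoid)

check_ordering()
```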
@@ -232,7 +234,9 @@ class TestSimilarityRanker:

     @patch("torch.sigmoid")
     @patch("torch.sort")
-    def test_prefix(self, mocked_sort, mocked_sigmoid):
+    @patch("torch.stack")
+    def test_prefix(self, mocked_stack, mocked_sort, mocked_sigmoid):
+        mocked_stack.return_value = torch.tensor([0])
         mocked_sort.return_value = (None, torch.tensor([0]))
         mocked_sigmoid.return_value = torch.tensor([0])
         embedder = TransformersSimilarityRanker(
@@ -261,7 +265,9 @@ class TestSimilarityRanker:
         )

     @patch("torch.sort")
-    def test_scale_score_false(self, mocked_sort):
+    @patch("torch.stack")
+    def test_scale_score_false(self, mocked_stack, mocked_sort):
+        mocked_stack.return_value = torch.FloatTensor([-10.6859, -8.9874])
         mocked_sort.return_value = (None, torch.tensor([0, 1]))
         embedder = TransformersSimilarityRanker(model="model", scale_score=False)
         embedder.model = MagicMock()
@@ -277,7 +283,9 @@ class TestSimilarityRanker:
         assert out["documents"][1].score == pytest.approx(-8.9874, abs=1e-4)

     @patch("torch.sort")
-    def test_score_threshold(self, mocked_sort):
+    @patch("torch.stack")
+    def test_score_threshold(self, mocked_stack, mocked_sort):
+        mocked_stack.return_value = torch.FloatTensor([0.955, 0.001])
         mocked_sort.return_value = (None, torch.tensor([0, 1]))
         embedder = TransformersSimilarityRanker(model="model", scale_score=False, score_threshold=0.1)
         embedder.model = MagicMock()
@@ -359,6 +367,48 @@ class TestSimilarityRanker:
         assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
         assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)

+    @pytest.mark.integration
+    @pytest.mark.parametrize(
+        "query,docs_before_texts,expected_first_text,scores",
+        [
+            (
+                "City in Bosnia and Herzegovina",
+                ["Berlin", "Belgrade", "Sarajevo"],
+                "Sarajevo",
+                [2.2864143829792738e-05, 0.00012495707778725773, 0.009869757108390331],
+            ),
+            (
+                "Machine learning",
+                ["Python", "Bakery in Paris", "Tesla Giga Berlin"],
+                "Python",
+                [1.9063229046878405e-05, 1.434577916370472e-05, 1.3049247172602918e-05],
+            ),
+            (
+                "Cubist movement",
+                ["Nirvana", "Pablo Picasso", "Coffee"],
+                "Pablo Picasso",
+                [1.3313065210240893e-05, 9.90335684036836e-05, 1.3518535524781328e-05],
+            ),
+        ],
+    )
+    def test_run_small_batch_size(self, query, docs_before_texts, expected_first_text, scores):
+        """
+        Test if the component ranks documents correctly.
+        """
+        ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2", batch_size=2)
+        ranker.warm_up()
+        docs_before = [Document(content=text) for text in docs_before_texts]
+        output = ranker.run(query=query, documents=docs_before)
+        docs_after = output["documents"]
+
+        assert len(docs_after) == 3
+        assert docs_after[0].content == expected_first_text
+
+        sorted_scores = sorted(scores, reverse=True)
+        assert docs_after[0].score == pytest.approx(sorted_scores[0], abs=1e-6)
+        assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
+        assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)
+
     # Returns an empty list if no documents are provided
     @pytest.mark.integration
     def test_returns_empty_list_if_no_documents_are_provided(self):