diff --git a/haystack/components/rankers/transformers_similarity.py b/haystack/components/rankers/transformers_similarity.py index ca9cd2519..d380768d3 100644 --- a/haystack/components/rankers/transformers_similarity.py +++ b/haystack/components/rankers/transformers_similarity.py @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: import accelerate # pylint: disable=unused-import # the library is used but not directly referenced import torch + from torch.utils.data import DataLoader, Dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer @@ -42,7 +43,7 @@ class TransformersSimilarityRanker: ``` """ - def __init__( + def __init__( # noqa: PLR0913 self, model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: Optional[ComponentDevice] = None, @@ -57,6 +58,7 @@ class TransformersSimilarityRanker: score_threshold: Optional[float] = None, model_kwargs: Optional[Dict[str, Any]] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, + batch_size: int = 16, ): """ Creates an instance of TransformersSimilarityRanker. @@ -93,6 +95,9 @@ class TransformersSimilarityRanker: :param tokenizer_kwargs: Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer. Refer to specific model documentation for available kwargs. + :param batch_size: + The batch size to use for inference. The higher the batch size, the more memory is required. + If you run into memory issues, reduce the batch size. :raises ValueError: If `top_k` is not > 0. @@ -117,6 +122,7 @@ class TransformersSimilarityRanker: model_kwargs = resolve_hf_device_map(device=device, model_kwargs=model_kwargs) self.model_kwargs = model_kwargs self.tokenizer_kwargs = tokenizer_kwargs or {} + self.batch_size = batch_size # Parameter validation if self.scale_score and self.calibration_factor is None: @@ -261,11 +267,28 @@ class TransformersSimilarityRanker: text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) query_doc_pairs.append([self.query_prefix + query, self.document_prefix + text_to_embed]) - features = self.tokenizer(query_doc_pairs, padding=True, truncation=True, return_tensors="pt").to( # type: ignore + class _Dataset(Dataset): + def __init__(self, batch_encoding): + self.batch_encoding = batch_encoding + + def __len__(self): + return len(self.batch_encoding["input_ids"]) + + def __getitem__(self, item): + return {key: self.batch_encoding.data[key][item] for key in self.batch_encoding.data.keys()} + + batch_enc = self.tokenizer(query_doc_pairs, padding=True, truncation=True, return_tensors="pt").to( # type: ignore self.device.first_device.to_torch() ) + dataset = _Dataset(batch_enc) + inp_dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False) + + similarity_scores = [] with torch.inference_mode(): - similarity_scores = self.model(**features).logits.squeeze(dim=1) # type: ignore + for features in inp_dataloader: + model_preds = self.model(**features).logits.squeeze(dim=1) # type: ignore + similarity_scores.extend(model_preds) + similarity_scores = torch.stack(similarity_scores) if scale_score: similarity_scores = torch.sigmoid(similarity_scores * calibration_factor) diff --git a/releasenotes/notes/ranker-add-batching-during-inference-f077411ec389a63b.yaml b/releasenotes/notes/ranker-add-batching-during-inference-f077411ec389a63b.yaml new file mode 100644 index 000000000..24e8191ac --- /dev/null +++ b/releasenotes/notes/ranker-add-batching-during-inference-f077411ec389a63b.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + We added batching during inference time to the TransformerSimilarityRanker to help prevent OOMs when ranking large amounts of Documents. diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py index 45e564b0f..6031d85e1 100644 --- a/test/components/rankers/test_transformers_similarity.py +++ b/test/components/rankers/test_transformers_similarity.py @@ -8,7 +8,7 @@ import pytest import torch from transformers.modeling_outputs import SequenceClassifierOutput -from haystack import ComponentError, Document +from haystack import Document from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice, DeviceMap @@ -202,7 +202,9 @@ class TestSimilarityRanker: @patch("torch.sigmoid") @patch("torch.sort") - def test_embed_meta(self, mocked_sort, mocked_sigmoid): + @patch("torch.stack") + def test_embed_meta(self, mocked_stack, mocked_sort, mocked_sigmoid): + mocked_stack.return_value = torch.tensor([0]) mocked_sort.return_value = (None, torch.tensor([0])) mocked_sigmoid.return_value = torch.tensor([0]) embedder = TransformersSimilarityRanker( @@ -232,7 +234,9 @@ class TestSimilarityRanker: @patch("torch.sigmoid") @patch("torch.sort") - def test_prefix(self, mocked_sort, mocked_sigmoid): + @patch("torch.stack") + def test_prefix(self, mocked_stack, mocked_sort, mocked_sigmoid): + mocked_stack.return_value = torch.tensor([0]) mocked_sort.return_value = (None, torch.tensor([0])) mocked_sigmoid.return_value = torch.tensor([0]) embedder = TransformersSimilarityRanker( @@ -261,7 +265,9 @@ class TestSimilarityRanker: ) @patch("torch.sort") - def test_scale_score_false(self, mocked_sort): + @patch("torch.stack") + def test_scale_score_false(self, mocked_stack, mocked_sort): + mocked_stack.return_value = torch.FloatTensor([-10.6859, -8.9874]) mocked_sort.return_value = (None, torch.tensor([0, 1])) embedder = TransformersSimilarityRanker(model="model", scale_score=False) embedder.model = MagicMock() @@ -277,7 +283,9 @@ class TestSimilarityRanker: assert out["documents"][1].score == pytest.approx(-8.9874, abs=1e-4) @patch("torch.sort") - def test_score_threshold(self, mocked_sort): + @patch("torch.stack") + def test_score_threshold(self, mocked_stack, mocked_sort): + mocked_stack.return_value = torch.FloatTensor([0.955, 0.001]) mocked_sort.return_value = (None, torch.tensor([0, 1])) embedder = TransformersSimilarityRanker(model="model", scale_score=False, score_threshold=0.1) embedder.model = MagicMock() @@ -359,6 +367,48 @@ class TestSimilarityRanker: assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6) assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6) + @pytest.mark.integration + @pytest.mark.parametrize( + "query,docs_before_texts,expected_first_text,scores", + [ + ( + "City in Bosnia and Herzegovina", + ["Berlin", "Belgrade", "Sarajevo"], + "Sarajevo", + [2.2864143829792738e-05, 0.00012495707778725773, 0.009869757108390331], + ), + ( + "Machine learning", + ["Python", "Bakery in Paris", "Tesla Giga Berlin"], + "Python", + [1.9063229046878405e-05, 1.434577916370472e-05, 1.3049247172602918e-05], + ), + ( + "Cubist movement", + ["Nirvana", "Pablo Picasso", "Coffee"], + "Pablo Picasso", + [1.3313065210240893e-05, 9.90335684036836e-05, 1.3518535524781328e-05], + ), + ], + ) + def test_run_small_batch_size(self, query, docs_before_texts, expected_first_text, scores): + """ + Test if the component ranks documents correctly. + """ + ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2", batch_size=2) + ranker.warm_up() + docs_before = [Document(content=text) for text in docs_before_texts] + output = ranker.run(query=query, documents=docs_before) + docs_after = output["documents"] + + assert len(docs_after) == 3 + assert docs_after[0].content == expected_first_text + + sorted_scores = sorted(scores, reverse=True) + assert docs_after[0].score == pytest.approx(sorted_scores[0], abs=1e-6) + assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6) + assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6) + # Returns an empty list if no documents are provided @pytest.mark.integration def test_returns_empty_list_if_no_documents_are_provided(self):