diff --git a/haystack/components/rankers/transformers_similarity.py b/haystack/components/rankers/transformers_similarity.py
index b2f0c7ae3..c93e9b3af 100644
--- a/haystack/components/rankers/transformers_similarity.py
+++ b/haystack/components/rankers/transformers_similarity.py
@@ -173,6 +173,7 @@ class TransformersSimilarityRanker:
             score_threshold=self.score_threshold,
             model_kwargs=self.model_kwargs,
             tokenizer_kwargs=self.tokenizer_kwargs,
+            batch_size=self.batch_size,
         )
         serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
diff --git a/releasenotes/notes/fix-to-dict-transformer-ranker-f981b37f67f6eec5.yaml b/releasenotes/notes/fix-to-dict-transformer-ranker-f981b37f67f6eec5.yaml
new file mode 100644
index 000000000..fe39113b8
--- /dev/null
+++ b/releasenotes/notes/fix-to-dict-transformer-ranker-f981b37f67f6eec5.yaml
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    The `batch_size` parameter has been added to the `to_dict` method of `TransformersSimilarityRanker`. This means serialization of `batch_size` now works as expected.
diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py
index 616bfa664..277c165fe 100644
--- a/test/components/rankers/test_transformers_similarity.py
+++ b/test/components/rankers/test_transformers_similarity.py
@@ -34,6 +34,7 @@ class TestSimilarityRanker:
                 "score_threshold": None,
                 "model_kwargs": {"device_map": ComponentDevice.resolve_device(None).to_hf()},
                 "tokenizer_kwargs": {},
+                "batch_size": 16,
             },
         }
@@ -50,6 +51,7 @@ class TestSimilarityRanker:
             score_threshold=0.01,
             model_kwargs={"torch_dtype": torch.float16},
             tokenizer_kwargs={"model_max_length": 512},
+            batch_size=32,
         )
         data = component.to_dict()
         assert data == {
@@ -71,6 +73,7 @@ class TestSimilarityRanker:
                     "device_map": ComponentDevice.from_str("cuda:0").to_hf(),
                 },  # torch_dtype is correctly serialized
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "batch_size": 32,
             },
         }
@@ -106,6 +109,7 @@ class TestSimilarityRanker:
                     "device_map": ComponentDevice.resolve_device(None).to_hf(),
                 },
                 "tokenizer_kwargs": {},
+                "batch_size": 16,
             },
         }
@@ -137,6 +141,7 @@ class TestSimilarityRanker:
                 "score_threshold": None,
                 "model_kwargs": {"device_map": expected},
                 "tokenizer_kwargs": {},
+                "batch_size": 16,
             },
         }
@@ -157,6 +162,7 @@ class TestSimilarityRanker:
                 "score_threshold": 0.01,
                 "model_kwargs": {"torch_dtype": "torch.float16"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "batch_size": 32,
             },
         }
@@ -178,6 +184,7 @@ class TestSimilarityRanker:
             "device_map": ComponentDevice.resolve_device(None).to_hf(),
         }
         assert component.tokenizer_kwargs == {"model_max_length": 512}
+        assert component.batch_size == 32

     def test_from_dict_no_default_parameters(self):
         data = {
@@ -199,6 +206,8 @@ class TestSimilarityRanker:
         assert component.score_threshold is None
         # torch_dtype is correctly deserialized
         assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()}
+        assert component.tokenizer_kwargs == {}
+        assert component.batch_size == 16

     @patch("torch.sigmoid")
     @patch("torch.sort")
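
For reference, a minimal round-trip sketch of the behaviour this patch fixes, assuming a Haystack 2.x install and the public import path `haystack.components.rankers`. Before this change, `to_dict()` omitted `batch_size`, so a ranker rebuilt with `from_dict()` silently fell back to the default of 16.

```python
from haystack.components.rankers import TransformersSimilarityRanker

# Construct a ranker with a non-default batch size; to_dict() does not
# require warm_up(), so this runs without loading the model weights.
ranker = TransformersSimilarityRanker(batch_size=32)

data = ranker.to_dict()
# With this fix, batch_size is part of the serialized init parameters.
assert data["init_parameters"]["batch_size"] == 32

# Deserializing now restores the configured value instead of the default (16).
restored = TransformersSimilarityRanker.from_dict(data)
assert restored.batch_size == 32
```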