fix: Add batch_size to to_dict of TransformersSimilarityRanker (#9248)

* Add missing batch_size to to_dict of similarity ranker

* Add reno
This commit is contained in:
Sebastian Husch Lee 2025-04-16 12:16:59 +02:00 committed by GitHub
parent a0d43fdc6e
commit cdc53cae78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 0 deletions

View File

@@ -173,6 +173,7 @@ class TransformersSimilarityRanker:
score_threshold=self.score_threshold,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
batch_size=self.batch_size,
)
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])

View File

@@ -0,0 +1,4 @@
---
fixes:
- |
The `batch_size` parameter has now been added to the `to_dict` function of `TransformersSimilarityRanker`. This means serialization of `batch_size` now works as expected.

View File

@@ -34,6 +34,7 @@ class TestSimilarityRanker:
"score_threshold": None,
"model_kwargs": {"device_map": ComponentDevice.resolve_device(None).to_hf()},
"tokenizer_kwargs": {},
"batch_size": 16,
},
}
@@ -50,6 +51,7 @@ class TestSimilarityRanker:
score_threshold=0.01,
model_kwargs={"torch_dtype": torch.float16},
tokenizer_kwargs={"model_max_length": 512},
batch_size=32,
)
data = component.to_dict()
assert data == {
@@ -71,6 +73,7 @@ class TestSimilarityRanker:
"device_map": ComponentDevice.from_str("cuda:0").to_hf(),
},  # torch_dtype is correctly serialized
"tokenizer_kwargs": {"model_max_length": 512},
"batch_size": 32,
},
}
@@ -106,6 +109,7 @@ class TestSimilarityRanker:
"device_map": ComponentDevice.resolve_device(None).to_hf(),
},
"tokenizer_kwargs": {},
"batch_size": 16,
},
}
@@ -137,6 +141,7 @@ class TestSimilarityRanker:
"score_threshold": None,
"model_kwargs": {"device_map": expected},
"tokenizer_kwargs": {},
"batch_size": 16,
},
}
@@ -157,6 +162,7 @@ class TestSimilarityRanker:
"score_threshold": 0.01,
"model_kwargs": {"torch_dtype": "torch.float16"},
"tokenizer_kwargs": {"model_max_length": 512},
"batch_size": 32,
},
}
@@ -178,6 +184,7 @@ class TestSimilarityRanker:
"device_map": ComponentDevice.resolve_device(None).to_hf(),
}
assert component.tokenizer_kwargs == {"model_max_length": 512}
assert component.batch_size == 32

def test_from_dict_no_default_parameters(self):
data = {
@@ -199,6 +206,8 @@ class TestSimilarityRanker:
assert component.score_threshold is None
# torch_dtype is correctly deserialized
assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()}
assert component.tokenizer_kwargs == {}
assert component.batch_size == 16

@patch("torch.sigmoid")
@patch("torch.sort")