mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-14 17:13:03 +00:00
fix: Add batch_size to to_dict of TransformersSimilarityRanker (#9248)
* Add missing batch_size to to_dict of similarity ranker * Add reno
This commit is contained in:
parent
a0d43fdc6e
commit
cdc53cae78
@ -173,6 +173,7 @@ class TransformersSimilarityRanker:
|
||||
score_threshold=self.score_threshold,
|
||||
model_kwargs=self.model_kwargs,
|
||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
|
||||
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
The batch_size parameter has now been added to the to_dict function of TransformersSimilarityRanker. This means that serialization of batch_size now works as expected.
|
||||
@ -34,6 +34,7 @@ class TestSimilarityRanker:
|
||||
"score_threshold": None,
|
||||
"model_kwargs": {"device_map": ComponentDevice.resolve_device(None).to_hf()},
|
||||
"tokenizer_kwargs": {},
|
||||
"batch_size": 16,
|
||||
},
|
||||
}
|
||||
|
||||
@ -50,6 +51,7 @@ class TestSimilarityRanker:
|
||||
score_threshold=0.01,
|
||||
model_kwargs={"torch_dtype": torch.float16},
|
||||
tokenizer_kwargs={"model_max_length": 512},
|
||||
batch_size=32,
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
@ -71,6 +73,7 @@ class TestSimilarityRanker:
|
||||
"device_map": ComponentDevice.from_str("cuda:0").to_hf(),
|
||||
}, # torch_dtype is correctly serialized
|
||||
"tokenizer_kwargs": {"model_max_length": 512},
|
||||
"batch_size": 32,
|
||||
},
|
||||
}
|
||||
|
||||
@ -106,6 +109,7 @@ class TestSimilarityRanker:
|
||||
"device_map": ComponentDevice.resolve_device(None).to_hf(),
|
||||
},
|
||||
"tokenizer_kwargs": {},
|
||||
"batch_size": 16,
|
||||
},
|
||||
}
|
||||
|
||||
@ -137,6 +141,7 @@ class TestSimilarityRanker:
|
||||
"score_threshold": None,
|
||||
"model_kwargs": {"device_map": expected},
|
||||
"tokenizer_kwargs": {},
|
||||
"batch_size": 16,
|
||||
},
|
||||
}
|
||||
|
||||
@ -157,6 +162,7 @@ class TestSimilarityRanker:
|
||||
"score_threshold": 0.01,
|
||||
"model_kwargs": {"torch_dtype": "torch.float16"},
|
||||
"tokenizer_kwargs": {"model_max_length": 512},
|
||||
"batch_size": 32,
|
||||
},
|
||||
}
|
||||
|
||||
@ -178,6 +184,7 @@ class TestSimilarityRanker:
|
||||
"device_map": ComponentDevice.resolve_device(None).to_hf(),
|
||||
}
|
||||
assert component.tokenizer_kwargs == {"model_max_length": 512}
|
||||
assert component.batch_size == 32
|
||||
|
||||
def test_from_dict_no_default_parameters(self):
|
||||
data = {
|
||||
@ -199,6 +206,8 @@ class TestSimilarityRanker:
|
||||
assert component.score_threshold is None
|
||||
# torch_dtype is correctly deserialized
|
||||
assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()}
|
||||
assert component.tokenizer_kwargs == {}
|
||||
assert component.batch_size == 16
|
||||
|
||||
@patch("torch.sigmoid")
|
||||
@patch("torch.sort")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user