mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-14 17:13:03 +00:00
fix: Add batch_size to to_dict of TransformersSimilarityRanker (#9248)
* Add missing batch_size to to_dict of similarity ranker * Add reno
This commit is contained in:
parent
a0d43fdc6e
commit
cdc53cae78
@ -173,6 +173,7 @@ class TransformersSimilarityRanker:
|
|||||||
score_threshold=self.score_threshold,
|
score_threshold=self.score_threshold,
|
||||||
model_kwargs=self.model_kwargs,
|
model_kwargs=self.model_kwargs,
|
||||||
tokenizer_kwargs=self.tokenizer_kwargs,
|
tokenizer_kwargs=self.tokenizer_kwargs,
|
||||||
|
batch_size=self.batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
|
serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
fixes:
|
||||||
|
- |
|
||||||
|
The batch_size parameter has now been added to the to_dict method of TransformersSimilarityRanker. This means serialization of batch_size now works as expected.
|
||||||
@ -34,6 +34,7 @@ class TestSimilarityRanker:
|
|||||||
"score_threshold": None,
|
"score_threshold": None,
|
||||||
"model_kwargs": {"device_map": ComponentDevice.resolve_device(None).to_hf()},
|
"model_kwargs": {"device_map": ComponentDevice.resolve_device(None).to_hf()},
|
||||||
"tokenizer_kwargs": {},
|
"tokenizer_kwargs": {},
|
||||||
|
"batch_size": 16,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,6 +51,7 @@ class TestSimilarityRanker:
|
|||||||
score_threshold=0.01,
|
score_threshold=0.01,
|
||||||
model_kwargs={"torch_dtype": torch.float16},
|
model_kwargs={"torch_dtype": torch.float16},
|
||||||
tokenizer_kwargs={"model_max_length": 512},
|
tokenizer_kwargs={"model_max_length": 512},
|
||||||
|
batch_size=32,
|
||||||
)
|
)
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
@ -71,6 +73,7 @@ class TestSimilarityRanker:
|
|||||||
"device_map": ComponentDevice.from_str("cuda:0").to_hf(),
|
"device_map": ComponentDevice.from_str("cuda:0").to_hf(),
|
||||||
}, # torch_dtype is correctly serialized
|
}, # torch_dtype is correctly serialized
|
||||||
"tokenizer_kwargs": {"model_max_length": 512},
|
"tokenizer_kwargs": {"model_max_length": 512},
|
||||||
|
"batch_size": 32,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,6 +109,7 @@ class TestSimilarityRanker:
|
|||||||
"device_map": ComponentDevice.resolve_device(None).to_hf(),
|
"device_map": ComponentDevice.resolve_device(None).to_hf(),
|
||||||
},
|
},
|
||||||
"tokenizer_kwargs": {},
|
"tokenizer_kwargs": {},
|
||||||
|
"batch_size": 16,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,6 +141,7 @@ class TestSimilarityRanker:
|
|||||||
"score_threshold": None,
|
"score_threshold": None,
|
||||||
"model_kwargs": {"device_map": expected},
|
"model_kwargs": {"device_map": expected},
|
||||||
"tokenizer_kwargs": {},
|
"tokenizer_kwargs": {},
|
||||||
|
"batch_size": 16,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -157,6 +162,7 @@ class TestSimilarityRanker:
|
|||||||
"score_threshold": 0.01,
|
"score_threshold": 0.01,
|
||||||
"model_kwargs": {"torch_dtype": "torch.float16"},
|
"model_kwargs": {"torch_dtype": "torch.float16"},
|
||||||
"tokenizer_kwargs": {"model_max_length": 512},
|
"tokenizer_kwargs": {"model_max_length": 512},
|
||||||
|
"batch_size": 32,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,6 +184,7 @@ class TestSimilarityRanker:
|
|||||||
"device_map": ComponentDevice.resolve_device(None).to_hf(),
|
"device_map": ComponentDevice.resolve_device(None).to_hf(),
|
||||||
}
|
}
|
||||||
assert component.tokenizer_kwargs == {"model_max_length": 512}
|
assert component.tokenizer_kwargs == {"model_max_length": 512}
|
||||||
|
assert component.batch_size == 32
|
||||||
|
|
||||||
def test_from_dict_no_default_parameters(self):
|
def test_from_dict_no_default_parameters(self):
|
||||||
data = {
|
data = {
|
||||||
@ -199,6 +206,8 @@ class TestSimilarityRanker:
|
|||||||
assert component.score_threshold is None
|
assert component.score_threshold is None
|
||||||
# torch_dtype is correctly deserialized
|
# torch_dtype is correctly deserialized
|
||||||
assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()}
|
assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()}
|
||||||
|
assert component.tokenizer_kwargs == {}
|
||||||
|
assert component.batch_size == 16
|
||||||
|
|
||||||
@patch("torch.sigmoid")
|
@patch("torch.sigmoid")
|
||||||
@patch("torch.sort")
|
@patch("torch.sort")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user