Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-05 11:27:15 +00:00)
fix: Fix from_dict methods of components using HF models to work with default values (#8003)
* Fix from_dict to work if device isn't provided in init params
* Minor refactoring of from_dict for components that load HF models
* Add tests
* Update tests to test loading with all default parameters
* Add more tests
* Add release notes
* Add unit test for whisper local
* Update reno
* Add fix for ExtractiveReader
* Fix NamedEntityExtractor
This commit is contained in:
parent f19131f13a
commit c121c86c4c
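The core of the change is how optional keys are read from init_parameters: bracket indexing raises a KeyError when a key is absent (the case where a component dict contains only its required parameters), while dict.get() simply returns None. A minimal sketch of the failure mode in plain Python, independent of Haystack:

init_params = {}  # what "init_parameters" looks like when only defaults are used

# Old pattern: raises KeyError because the key is simply not there.
try:
    if init_params["device"] is not None:
        print("deserialize the device")
except KeyError as err:
    print(f"old pattern fails: missing key {err}")

# New pattern: .get() returns None for a missing key, so the branch is skipped
# and the component later falls back to its default device resolution.
if init_params.get("device") is not None:
    print("deserialize the device")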
@@ -104,7 +104,7 @@ class LocalWhisperTranscriber:
             The deserialized component.
         """
         init_params = data["init_parameters"]
-        if init_params["device"] is not None:
+        if init_params.get("device") is not None:
             init_params["device"] = ComponentDevice.from_dict(init_params["device"])
         return default_from_dict(cls, data)
@@ -126,9 +126,9 @@ class SentenceTransformersDocumentEmbedder:
             Deserialized component.
         """
         init_params = data["init_parameters"]
-        if init_params["device"] is not None:
+        if init_params.get("device") is not None:
             init_params["device"] = ComponentDevice.from_dict(init_params["device"])
-        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        deserialize_secrets_inplace(init_params, keys=["token"])
         return default_from_dict(cls, data)

     def warm_up(self):
@@ -116,9 +116,9 @@ class SentenceTransformersTextEmbedder:
             Deserialized component.
         """
         init_params = data["init_parameters"]
-        if init_params["device"] is not None:
+        if init_params.get("device") is not None:
             init_params["device"] = ComponentDevice.from_dict(init_params["device"])
-        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        deserialize_secrets_inplace(init_params, keys=["token"])
         return default_from_dict(cls, data)

     def warm_up(self):
@@ -221,7 +221,7 @@ class NamedEntityExtractor:
         """
         try:
             init_params = data["init_parameters"]
-            if init_params["device"] is not None:
+            if init_params.get("device") is not None:
                 init_params["device"] = ComponentDevice.from_dict(init_params["device"])
             init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
             return default_from_dict(cls, data)
@@ -142,9 +142,9 @@ class SentenceTransformersDiversityRanker:
             The deserialized component.
         """
         init_params = data["init_parameters"]
-        if init_params["device"] is not None:
+        if init_params.get("device") is not None:
             init_params["device"] = ComponentDevice.from_dict(init_params["device"])
-        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        deserialize_secrets_inplace(init_params, keys=["token"])
         return default_from_dict(cls, data)

     def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
@@ -176,11 +176,12 @@ class TransformersSimilarityRanker:
         :returns:
             Deserialized component.
         """
-        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
         init_params = data["init_parameters"]
-        if init_params["device"] is not None:
+        if init_params.get("device") is not None:
             init_params["device"] = ComponentDevice.from_dict(init_params["device"])
-        deserialize_hf_model_kwargs(init_params["model_kwargs"])
+        if init_params.get("model_kwargs") is not None:
+            deserialize_hf_model_kwargs(init_params["model_kwargs"])
+        deserialize_secrets_inplace(init_params, keys=["token"])

         return default_from_dict(cls, data)
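Beyond the device check, this hunk also guards deserialize_hf_model_kwargs behind a presence check and moves the secrets deserialization after the init_params alias is taken. A rough, hypothetical sketch of the resulting control flow (the helper name and the stand-in transformations are illustrative, not the actual Haystack code):

def deserialize_optional_init_params(init_params: dict) -> dict:
    # Hypothetical helper mirroring the guarded pattern in the hunk above:
    # every optional key is post-processed only if it is actually present.
    if init_params.get("device") is not None:
        init_params["device"] = ("resolved", init_params["device"])  # stand-in for ComponentDevice.from_dict
    if init_params.get("model_kwargs") is not None:
        init_params["model_kwargs"]["_deserialized"] = True  # stand-in for deserialize_hf_model_kwargs
    return init_params

print(deserialize_optional_init_params({}))  # {} -- nothing to do, no KeyError
print(deserialize_optional_init_params({"model_kwargs": {"torch_dtype": "torch.float16"}}))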
@@ -170,10 +170,11 @@ class ExtractiveReader:
             Deserialized component.
         """
         init_params = data["init_parameters"]
-        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
-        if init_params["device"] is not None:
+        deserialize_secrets_inplace(init_params, keys=["token"])
+        if init_params.get("device") is not None:
             init_params["device"] = ComponentDevice.from_dict(init_params["device"])
-        deserialize_hf_model_kwargs(init_params["model_kwargs"])
+        if init_params.get("model_kwargs") is not None:
+            deserialize_hf_model_kwargs(init_params["model_kwargs"])

         return default_from_dict(cls, data)
@@ -0,0 +1,3 @@
+fixes:
+  - |
+    Fix the from_dict methods of TransformersSimilarityRanker, SentenceTransformersDiversityRanker, SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder, and LocalWhisperTranscriber so they work when init_parameters contains only the required parameters.
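In practice, the fix means a component entry whose init_parameters mapping is empty (or contains only required keys) can now be deserialized. A usage sketch, assuming a Haystack 2.x install; from_dict does not download the model, that only happens on warm_up:

from haystack.components.embedders import SentenceTransformersTextEmbedder

data = {
    "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
    "init_parameters": {},  # no device, token, or model: everything falls back to defaults
}
embedder = SentenceTransformersTextEmbedder.from_dict(data)
print(embedder.model)  # "sentence-transformers/all-mpnet-base-v2", per the tests below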
@@ -74,6 +74,13 @@ class TestLocalWhisperTranscriber:
         assert transcriber.whisper_params == {}
         assert transcriber._model is None

+    def test_from_dict_no_default_parameters(self):
+        data = {"type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber", "init_parameters": {}}
+        transcriber = LocalWhisperTranscriber.from_dict(data)
+        assert transcriber.model == "large"
+        assert transcriber.device == ComponentDevice.resolve_device(None)
+        assert transcriber.whisper_params == {}
+
     def test_from_dict_none_device(self):
         data = {
             "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
@@ -137,6 +137,25 @@ class TestSentenceTransformersDocumentEmbedder:
         assert component.trust_remote_code
         assert component.meta_fields_to_embed == ["meta_field"]

+    def test_from_dict_no_default_parameters(self):
+        component = SentenceTransformersDocumentEmbedder.from_dict(
+            {
+                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
+                "init_parameters": {},
+            }
+        )
+        assert component.model == "sentence-transformers/all-mpnet-base-v2"
+        assert component.device == ComponentDevice.resolve_device(None)
+        assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
+        assert component.prefix == ""
+        assert component.suffix == ""
+        assert component.batch_size == 32
+        assert component.progress_bar is True
+        assert component.normalize_embeddings is False
+        assert component.embedding_separator == "\n"
+        assert component.trust_remote_code is False
+        assert component.meta_fields_to_embed == []
+
     def test_from_dict_none_device(self):
         init_parameters = {
             "model": "model",
@@ -122,6 +122,22 @@ class TestSentenceTransformersTextEmbedder:
         assert component.normalize_embeddings is False
         assert component.trust_remote_code is False

+    def test_from_dict_no_default_parameters(self):
+        data = {
+            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
+            "init_parameters": {},
+        }
+        component = SentenceTransformersTextEmbedder.from_dict(data)
+        assert component.model == "sentence-transformers/all-mpnet-base-v2"
+        assert component.device == ComponentDevice.resolve_device(None)
+        assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
+        assert component.prefix == ""
+        assert component.suffix == ""
+        assert component.batch_size == 32
+        assert component.progress_bar is True
+        assert component.normalize_embeddings is False
+        assert component.trust_remote_code is False
+
     def test_from_dict_none_device(self):
         data = {
             "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
@@ -40,6 +40,17 @@ def test_named_entity_extractor_serde():
     _ = NamedEntityExtractor.from_dict(serde_data)


+def test_named_entity_extractor_from_dict_no_default_parameters_hf():
+    data = {
+        "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
+        "init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"},
+    }
+    extractor = NamedEntityExtractor.from_dict(data)
+
+    assert extractor._backend.model_name == "dslim/bert-base-NER"
+    assert extractor._backend.device == ComponentDevice.resolve_device(None)
+
+
 # tests for NamedEntityExtractor serialization/deserialization in a pipeline
 def test_named_entity_extractor_pipeline_serde(tmp_path):
     extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
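Note that this test passes the backend as the string "HUGGING_FACE"; from_dict converts it back to the enum member by name, as shown in the NamedEntityExtractor hunk above. A minimal illustration of that name-lookup pattern with a stand-in enum (not the actual Haystack class):

from enum import Enum

class Backend(Enum):  # stand-in for NamedEntityExtractorBackend
    HUGGING_FACE = "hugging_face"
    SPACY = "spacy"

# Enum name lookup: indexing by the member name returns the member itself.
assert Backend["HUGGING_FACE"] is Backend.HUGGING_FACE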
@@ -144,6 +144,25 @@ class TestSentenceTransformersDiversityRanker:
         assert ranker.meta_fields_to_embed == []
         assert ranker.embedding_separator == "\n"

+    def test_from_dict_no_default_parameters(self):
+        data = {
+            "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
+            "init_parameters": {},
+        }
+        ranker = SentenceTransformersDiversityRanker.from_dict(data)
+
+        assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
+        assert ranker.top_k == 10
+        assert ranker.device == ComponentDevice.resolve_device(None)
+        assert ranker.similarity == "cosine"
+        assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
+        assert ranker.query_prefix == ""
+        assert ranker.document_prefix == ""
+        assert ranker.query_suffix == ""
+        assert ranker.document_suffix == ""
+        assert ranker.meta_fields_to_embed == []
+        assert ranker.embedding_separator == "\n"
+
     def test_to_dict_with_custom_init_parameters(self):
         component = SentenceTransformersDiversityRanker(
             model="sentence-transformers/msmarco-distilbert-base-v4",
@@ -172,6 +172,27 @@ class TestSimilarityRanker:
             "device_map": ComponentDevice.resolve_device(None).to_hf(),
         }

+    def test_from_dict_no_default_parameters(self):
+        data = {
+            "type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",
+            "init_parameters": {},
+        }
+
+        component = TransformersSimilarityRanker.from_dict(data)
+        assert component.device is None
+        assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2"
+        assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
+        assert component.top_k == 10
+        assert component.query_prefix == ""
+        assert component.document_prefix == ""
+        assert component.meta_fields_to_embed == []
+        assert component.embedding_separator == "\n"
+        assert component.scale_score
+        assert component.calibration_factor == 1.0
+        assert component.score_threshold is None
+        # torch_dtype is correctly deserialized
+        assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()}
+
     @patch("torch.sigmoid")
     @patch("torch.sort")
     def test_embed_meta(self, mocked_sort, mocked_sigmoid):
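Several of the expected values above are computed with ComponentDevice.resolve_device(None), which resolves to the best device available on the machine running the tests, so the assertions match whatever device from_dict resolved locally. A quick check, assuming haystack-ai is installed:

from haystack.utils import ComponentDevice

resolved = ComponentDevice.resolve_device(None)
print(resolved.to_hf())  # e.g. "cuda:0" on a GPU machine, "cpu" otherwise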
@@ -243,6 +243,25 @@ def test_from_dict():
     }


+def test_from_dict_no_default_parameters():
+    data = {"type": "haystack.components.readers.extractive.ExtractiveReader", "init_parameters": {}}
+
+    component = ExtractiveReader.from_dict(data)
+    assert component.model_name_or_path == "deepset/roberta-base-squad2-distilled"
+    assert component.device is None
+    assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
+    assert component.top_k == 20
+    assert component.score_threshold is None
+    assert component.max_seq_length == 384
+    assert component.stride == 128
+    assert component.max_batch_size is None
+    assert component.answers_per_seq is None
+    assert component.no_answer
+    assert component.calibration_factor == 0.1
+    assert component.overlap_threshold == 0.01
+    assert component.model_kwargs == {"device_map": ComponentDevice.resolve_device(None).to_hf()}
+
+
 def test_from_dict_no_token():
     data = {
         "type": "haystack.components.readers.extractive.ExtractiveReader",