mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-03 13:23:38 +00:00
fix: Update device deserialization for components that use local models (#7686)
* fix: Update device deserialization for SentenceTransformersTextEmbedder * Add unit test * Fix unit test * Make same change to doc embedder * Add release notes * Add same change to Diversity Ranker and Named Entity Extractor * Add unit test * Add the same for whisper local * Update release notes
This commit is contained in:
parent
811b93db91
commit
a2be90b95a
@ -90,9 +90,9 @@ class LocalWhisperTranscriber:
|
|||||||
:returns:
|
:returns:
|
||||||
The deserialized component.
|
The deserialized component.
|
||||||
"""
|
"""
|
||||||
serialized_device = data["init_parameters"]["device"]
|
init_params = data["init_parameters"]
|
||||||
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
|
if init_params["device"] is not None:
|
||||||
|
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
@component.output_types(documents=List[Document])
|
@component.output_types(documents=List[Document])
|
||||||
|
@ -125,9 +125,9 @@ class SentenceTransformersDocumentEmbedder:
|
|||||||
:returns:
|
:returns:
|
||||||
Deserialized component.
|
Deserialized component.
|
||||||
"""
|
"""
|
||||||
serialized_device = data["init_parameters"]["device"]
|
init_params = data["init_parameters"]
|
||||||
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
|
if init_params["device"] is not None:
|
||||||
|
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
@ -115,9 +115,9 @@ class SentenceTransformersTextEmbedder:
|
|||||||
:returns:
|
:returns:
|
||||||
Deserialized component.
|
Deserialized component.
|
||||||
"""
|
"""
|
||||||
serialized_device = data["init_parameters"]["device"]
|
init_params = data["init_parameters"]
|
||||||
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
|
if init_params["device"] is not None:
|
||||||
|
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
@ -215,7 +215,8 @@ class NamedEntityExtractor:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
init_params = data["init_parameters"]
|
init_params = data["init_parameters"]
|
||||||
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
if init_params["device"] is not None:
|
||||||
|
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
|
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
|
||||||
|
@ -141,9 +141,9 @@ class SentenceTransformersDiversityRanker:
|
|||||||
:returns:
|
:returns:
|
||||||
The deserialized component.
|
The deserialized component.
|
||||||
"""
|
"""
|
||||||
serialized_device = data["init_parameters"]["device"]
|
init_params = data["init_parameters"]
|
||||||
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
|
if init_params["device"] is not None:
|
||||||
|
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
|
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
fixes:
|
||||||
|
- |
|
||||||
|
Updates the from_dict method of SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder, NamedEntityExtractor, SentenceTransformersDiversityRanker and LocalWhisperTranscriber to allow None as a valid value for device when deserializing from a YAML file.
|
||||||
|
This allows a deserialized pipeline to auto-determine what device to use using the ComponentDevice.resolve_device logic.
|
@ -72,6 +72,17 @@ class TestLocalWhisperTranscriber:
|
|||||||
assert transcriber.whisper_params == {}
|
assert transcriber.whisper_params == {}
|
||||||
assert transcriber._model is None
|
assert transcriber._model is None
|
||||||
|
|
||||||
|
def test_from_dict_none_device(self):
|
||||||
|
data = {
|
||||||
|
"type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
|
||||||
|
"init_parameters": {"model": "tiny", "device": None, "whisper_params": {}},
|
||||||
|
}
|
||||||
|
transcriber = LocalWhisperTranscriber.from_dict(data)
|
||||||
|
assert transcriber.model == "tiny"
|
||||||
|
assert transcriber.device == ComponentDevice.resolve_device(None)
|
||||||
|
assert transcriber.whisper_params == {}
|
||||||
|
assert transcriber._model is None
|
||||||
|
|
||||||
def test_warmup(self):
|
def test_warmup(self):
|
||||||
with patch("haystack.components.audio.whisper_local.whisper") as mocked_whisper:
|
with patch("haystack.components.audio.whisper_local.whisper") as mocked_whisper:
|
||||||
transcriber = LocalWhisperTranscriber(model="large-v2", device=ComponentDevice.from_str("cpu"))
|
transcriber = LocalWhisperTranscriber(model="large-v2", device=ComponentDevice.from_str("cpu"))
|
||||||
|
@ -137,6 +137,38 @@ class TestSentenceTransformersDocumentEmbedder:
|
|||||||
assert component.trust_remote_code
|
assert component.trust_remote_code
|
||||||
assert component.meta_fields_to_embed == ["meta_field"]
|
assert component.meta_fields_to_embed == ["meta_field"]
|
||||||
|
|
||||||
|
def test_from_dict_none_device(self):
|
||||||
|
init_parameters = {
|
||||||
|
"model": "model",
|
||||||
|
"device": None,
|
||||||
|
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||||
|
"prefix": "prefix",
|
||||||
|
"suffix": "suffix",
|
||||||
|
"batch_size": 64,
|
||||||
|
"progress_bar": False,
|
||||||
|
"normalize_embeddings": True,
|
||||||
|
"embedding_separator": " - ",
|
||||||
|
"meta_fields_to_embed": ["meta_field"],
|
||||||
|
"trust_remote_code": True,
|
||||||
|
}
|
||||||
|
component = SentenceTransformersDocumentEmbedder.from_dict(
|
||||||
|
{
|
||||||
|
"type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
|
||||||
|
"init_parameters": init_parameters,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert component.model == "model"
|
||||||
|
assert component.device == ComponentDevice.resolve_device(None)
|
||||||
|
assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||||
|
assert component.prefix == "prefix"
|
||||||
|
assert component.suffix == "suffix"
|
||||||
|
assert component.batch_size == 64
|
||||||
|
assert component.progress_bar is False
|
||||||
|
assert component.normalize_embeddings is True
|
||||||
|
assert component.embedding_separator == " - "
|
||||||
|
assert component.trust_remote_code
|
||||||
|
assert component.meta_fields_to_embed == ["meta_field"]
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
|
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||||
)
|
)
|
||||||
|
@ -122,6 +122,32 @@ class TestSentenceTransformersTextEmbedder:
|
|||||||
assert component.normalize_embeddings is False
|
assert component.normalize_embeddings is False
|
||||||
assert component.trust_remote_code is False
|
assert component.trust_remote_code is False
|
||||||
|
|
||||||
|
def test_from_dict_none_device(self):
|
||||||
|
data = {
|
||||||
|
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
|
||||||
|
"init_parameters": {
|
||||||
|
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||||
|
"model": "model",
|
||||||
|
"device": None,
|
||||||
|
"prefix": "",
|
||||||
|
"suffix": "",
|
||||||
|
"batch_size": 32,
|
||||||
|
"progress_bar": True,
|
||||||
|
"normalize_embeddings": False,
|
||||||
|
"trust_remote_code": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
component = SentenceTransformersTextEmbedder.from_dict(data)
|
||||||
|
assert component.model == "model"
|
||||||
|
assert component.device == ComponentDevice.resolve_device(None)
|
||||||
|
assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||||
|
assert component.prefix == ""
|
||||||
|
assert component.suffix == ""
|
||||||
|
assert component.batch_size == 32
|
||||||
|
assert component.progress_bar is True
|
||||||
|
assert component.normalize_embeddings is False
|
||||||
|
assert component.trust_remote_code is False
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||||
)
|
)
|
||||||
|
@ -40,3 +40,17 @@ def test_named_entity_extractor_serde():
|
|||||||
with pytest.raises(DeserializationError, match=r"Couldn't deserialize"):
|
with pytest.raises(DeserializationError, match=r"Couldn't deserialize"):
|
||||||
serde_data["init_parameters"].pop("backend")
|
serde_data["init_parameters"].pop("backend")
|
||||||
_ = NamedEntityExtractor.from_dict(serde_data)
|
_ = NamedEntityExtractor.from_dict(serde_data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_named_entity_extractor_serde_none_device():
|
||||||
|
extractor = NamedEntityExtractor(
|
||||||
|
backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None
|
||||||
|
)
|
||||||
|
|
||||||
|
serde_data = extractor.to_dict()
|
||||||
|
new_extractor = NamedEntityExtractor.from_dict(serde_data)
|
||||||
|
|
||||||
|
assert type(new_extractor._backend) == type(extractor._backend)
|
||||||
|
assert new_extractor._backend.model_name == extractor._backend.model_name
|
||||||
|
assert new_extractor._backend.device == extractor._backend.device
|
||||||
|
@ -113,6 +113,37 @@ class TestSentenceTransformersDiversityRanker:
|
|||||||
assert ranker.meta_fields_to_embed == []
|
assert ranker.meta_fields_to_embed == []
|
||||||
assert ranker.embedding_separator == "\n"
|
assert ranker.embedding_separator == "\n"
|
||||||
|
|
||||||
|
def test_from_dict_none_device(self):
|
||||||
|
data = {
|
||||||
|
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
|
||||||
|
"init_parameters": {
|
||||||
|
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
"top_k": 10,
|
||||||
|
"device": None,
|
||||||
|
"similarity": "cosine",
|
||||||
|
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||||
|
"query_prefix": "",
|
||||||
|
"document_prefix": "",
|
||||||
|
"query_suffix": "",
|
||||||
|
"document_suffix": "",
|
||||||
|
"meta_fields_to_embed": [],
|
||||||
|
"embedding_separator": "\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ranker = SentenceTransformersDiversityRanker.from_dict(data)
|
||||||
|
|
||||||
|
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
assert ranker.top_k == 10
|
||||||
|
assert ranker.device == ComponentDevice.resolve_device(None)
|
||||||
|
assert ranker.similarity == "cosine"
|
||||||
|
assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||||
|
assert ranker.query_prefix == ""
|
||||||
|
assert ranker.document_prefix == ""
|
||||||
|
assert ranker.query_suffix == ""
|
||||||
|
assert ranker.document_suffix == ""
|
||||||
|
assert ranker.meta_fields_to_embed == []
|
||||||
|
assert ranker.embedding_separator == "\n"
|
||||||
|
|
||||||
def test_to_dict_with_custom_init_parameters(self):
|
def test_to_dict_with_custom_init_parameters(self):
|
||||||
component = SentenceTransformersDiversityRanker(
|
component = SentenceTransformersDiversityRanker(
|
||||||
model="sentence-transformers/msmarco-distilbert-base-v4",
|
model="sentence-transformers/msmarco-distilbert-base-v4",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user