From a2be90b95a402f71ef158b38523f5249dc94ffef Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Tue, 14 May 2024 08:36:14 +0200 Subject: [PATCH] fix: Update device deserialization for components that use local models (#7686) * fix: Update device deserialization for SentenceTransformersTextEmbedder * Add unit test * Fix unit test * Make same change to doc embedder * Add release notes * Add same change to Diversity Ranker and Named Entity Extractor * Add unit test * Add the same for whisper local * Update release notes --- haystack/components/audio/whisper_local.py | 6 ++-- ...sentence_transformers_document_embedder.py | 6 ++-- .../sentence_transformers_text_embedder.py | 6 ++-- .../extractors/named_entity_extractor.py | 3 +- .../sentence_transformers_diversity.py | 6 ++-- ...lization-st-embedder-c4efad96dd3869d5.yaml | 5 +++ test/components/audio/test_whisper_local.py | 11 +++++++ ...sentence_transformers_document_embedder.py | 32 +++++++++++++++++++ ...est_sentence_transformers_text_embedder.py | 26 +++++++++++++++ .../extractors/test_named_entity_extractor.py | 14 ++++++++ .../test_sentence_transformers_diversity.py | 31 ++++++++++++++++++ 11 files changed, 133 insertions(+), 13 deletions(-) create mode 100644 releasenotes/notes/fix-device-deserialization-st-embedder-c4efad96dd3869d5.yaml diff --git a/haystack/components/audio/whisper_local.py b/haystack/components/audio/whisper_local.py index f60c76f57..5a96f40e4 100644 --- a/haystack/components/audio/whisper_local.py +++ b/haystack/components/audio/whisper_local.py @@ -90,9 +90,9 @@ class LocalWhisperTranscriber: :returns: The deserialized component. 
""" - serialized_device = data["init_parameters"]["device"] - data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device) - + init_params = data["init_parameters"] + if init_params["device"] is not None: + init_params["device"] = ComponentDevice.from_dict(init_params["device"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/haystack/components/embedders/sentence_transformers_document_embedder.py b/haystack/components/embedders/sentence_transformers_document_embedder.py index 010e4938d..40550d572 100644 --- a/haystack/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/components/embedders/sentence_transformers_document_embedder.py @@ -125,9 +125,9 @@ class SentenceTransformersDocumentEmbedder: :returns: Deserialized component. """ - serialized_device = data["init_parameters"]["device"] - data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device) - + init_params = data["init_parameters"] + if init_params["device"] is not None: + init_params["device"] = ComponentDevice.from_dict(init_params["device"]) deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) return default_from_dict(cls, data) diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py index 5907fbb27..0457f8815 100644 --- a/haystack/components/embedders/sentence_transformers_text_embedder.py +++ b/haystack/components/embedders/sentence_transformers_text_embedder.py @@ -115,9 +115,9 @@ class SentenceTransformersTextEmbedder: :returns: Deserialized component. 
""" - serialized_device = data["init_parameters"]["device"] - data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device) - + init_params = data["init_parameters"] + if init_params["device"] is not None: + init_params["device"] = ComponentDevice.from_dict(init_params["device"]) deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) return default_from_dict(cls, data) diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index 649b7dc59..a8c6c15bc 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -215,7 +215,8 @@ class NamedEntityExtractor: """ try: init_params = data["init_parameters"] - init_params["device"] = ComponentDevice.from_dict(init_params["device"]) + if init_params["device"] is not None: + init_params["device"] = ComponentDevice.from_dict(init_params["device"]) return default_from_dict(cls, data) except Exception as e: raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index 43f5a2417..c1a216533 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -141,9 +141,9 @@ class SentenceTransformersDiversityRanker: :returns: The deserialized component. 
""" - serialized_device = data["init_parameters"]["device"] - data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device) - + init_params = data["init_parameters"] + if init_params["device"] is not None: + init_params["device"] = ComponentDevice.from_dict(init_params["device"]) deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) return default_from_dict(cls, data) diff --git a/releasenotes/notes/fix-device-deserialization-st-embedder-c4efad96dd3869d5.yaml b/releasenotes/notes/fix-device-deserialization-st-embedder-c4efad96dd3869d5.yaml new file mode 100644 index 000000000..6bb0a4d2b --- /dev/null +++ b/releasenotes/notes/fix-device-deserialization-st-embedder-c4efad96dd3869d5.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Updates the from_dict method of SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder, NamedEntityExtractor, SentenceTransformersDiversityRanker and LocalWhisperTranscriber to allow None as a valid value for device when deserializing from a YAML file. + This allows a deserialized pipeline to auto-determine what device to use using the ComponentDevice.resolve_device logic. 
diff --git a/test/components/audio/test_whisper_local.py b/test/components/audio/test_whisper_local.py index 6a6c3a8f2..6cbd43575 100644 --- a/test/components/audio/test_whisper_local.py +++ b/test/components/audio/test_whisper_local.py @@ -72,6 +72,17 @@ class TestLocalWhisperTranscriber: assert transcriber.whisper_params == {} assert transcriber._model is None + def test_from_dict_none_device(self): + data = { + "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber", + "init_parameters": {"model": "tiny", "device": None, "whisper_params": {}}, + } + transcriber = LocalWhisperTranscriber.from_dict(data) + assert transcriber.model == "tiny" + assert transcriber.device == ComponentDevice.resolve_device(None) + assert transcriber.whisper_params == {} + assert transcriber._model is None + def test_warmup(self): with patch("haystack.components.audio.whisper_local.whisper") as mocked_whisper: transcriber = LocalWhisperTranscriber(model="large-v2", device=ComponentDevice.from_str("cpu")) diff --git a/test/components/embedders/test_sentence_transformers_document_embedder.py b/test/components/embedders/test_sentence_transformers_document_embedder.py index e9fc3e3c6..75564188a 100644 --- a/test/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/components/embedders/test_sentence_transformers_document_embedder.py @@ -137,6 +137,38 @@ class TestSentenceTransformersDocumentEmbedder: assert component.trust_remote_code assert component.meta_fields_to_embed == ["meta_field"] + def test_from_dict_none_device(self): + init_parameters = { + "model": "model", + "device": None, + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "prefix": "prefix", + "suffix": "suffix", + "batch_size": 64, + "progress_bar": False, + "normalize_embeddings": True, + "embedding_separator": " - ", + "meta_fields_to_embed": ["meta_field"], + "trust_remote_code": True, + } + component = SentenceTransformersDocumentEmbedder.from_dict( + 
{ + "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder", + "init_parameters": init_parameters, + } + ) + assert component.model == "model" + assert component.device == ComponentDevice.resolve_device(None) + assert component.token == Secret.from_env_var("ENV_VAR", strict=False) + assert component.prefix == "prefix" + assert component.suffix == "suffix" + assert component.batch_size == 64 + assert component.progress_bar is False + assert component.normalize_embeddings is True + assert component.embedding_separator == " - " + assert component.trust_remote_code + assert component.meta_fields_to_embed == ["meta_field"] + @patch( "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory" ) diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py index 433a51252..ec9234b6c 100644 --- a/test/components/embedders/test_sentence_transformers_text_embedder.py +++ b/test/components/embedders/test_sentence_transformers_text_embedder.py @@ -122,6 +122,32 @@ class TestSentenceTransformersTextEmbedder: assert component.normalize_embeddings is False assert component.trust_remote_code is False + def test_from_dict_none_device(self): + data = { + "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder", + "init_parameters": { + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "model": "model", + "device": None, + "prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "normalize_embeddings": False, + "trust_remote_code": False, + }, + } + component = SentenceTransformersTextEmbedder.from_dict(data) + assert component.model == "model" + assert component.device == ComponentDevice.resolve_device(None) + assert component.token == Secret.from_env_var("HF_API_TOKEN", 
strict=False) + assert component.prefix == "" + assert component.suffix == "" + assert component.batch_size == 32 + assert component.progress_bar is True + assert component.normalize_embeddings is False + assert component.trust_remote_code is False + @patch( "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory" ) diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index 140752f26..d47ae69dd 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -40,3 +40,17 @@ def test_named_entity_extractor_serde(): with pytest.raises(DeserializationError, match=r"Couldn't deserialize"): serde_data["init_parameters"].pop("backend") _ = NamedEntityExtractor.from_dict(serde_data) + + +@pytest.mark.unit +def test_named_entity_extractor_serde_none_device(): + extractor = NamedEntityExtractor( + backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None + ) + + serde_data = extractor.to_dict() + new_extractor = NamedEntityExtractor.from_dict(serde_data) + + assert type(new_extractor._backend) == type(extractor._backend) + assert new_extractor._backend.model_name == extractor._backend.model_name + assert new_extractor._backend.device == extractor._backend.device diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index a7fade57f..b4885d327 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -113,6 +113,37 @@ class TestSentenceTransformersDiversityRanker: assert ranker.meta_fields_to_embed == [] assert ranker.embedding_separator == "\n" + def test_from_dict_none_device(self): + data = { + "type": 
"haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", + "init_parameters": { + "model": "sentence-transformers/all-MiniLM-L6-v2", + "top_k": 10, + "device": None, + "similarity": "cosine", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "query_prefix": "", + "document_prefix": "", + "query_suffix": "", + "document_suffix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + ranker = SentenceTransformersDiversityRanker.from_dict(data) + + assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" + assert ranker.top_k == 10 + assert ranker.device == ComponentDevice.resolve_device(None) + assert ranker.similarity == "cosine" + assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert ranker.query_prefix == "" + assert ranker.document_prefix == "" + assert ranker.query_suffix == "" + assert ranker.document_suffix == "" + assert ranker.meta_fields_to_embed == [] + assert ranker.embedding_separator == "\n" + def test_to_dict_with_custom_init_parameters(self): component = SentenceTransformersDiversityRanker( model="sentence-transformers/msmarco-distilbert-base-v4",