From 4352b1688e38bd60f9967ee6f7ad101aed65d4e0 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 14 May 2024 12:24:55 +0200 Subject: [PATCH] fix: Fix NamedEntityExtractor serde (#7684) * Fix NamedEntityExtractor serde * Add release note * Linting, remove unit markers --- .../extractors/named_entity_extractor.py | 3 ++- ...r-serde-improvements-28b594be5a38f175.yaml | 4 ++++ .../extractors/test_named_entity_extractor.py | 19 +++++++++++++++---- 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 releasenotes/notes/named-entity-extractor-serde-improvements-28b594be5a38f175.yaml diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index a8c6c15bc..93fe20ece 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -197,7 +197,7 @@ class NamedEntityExtractor: """ return default_to_dict( self, - backend=self._backend.type, + backend=self._backend.type.name, model=self._backend.model_name, device=self._backend.device.to_dict(), pipeline_kwargs=self._backend._pipeline_kwargs, @@ -217,6 +217,7 @@ class NamedEntityExtractor: init_params = data["init_parameters"] if init_params["device"] is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) + init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]] return default_from_dict(cls, data) except Exception as e: raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e diff --git a/releasenotes/notes/named-entity-extractor-serde-improvements-28b594be5a38f175.yaml b/releasenotes/notes/named-entity-extractor-serde-improvements-28b594be5a38f175.yaml new file mode 100644 index 000000000..af37a9c36 --- /dev/null +++ b/releasenotes/notes/named-entity-extractor-serde-improvements-28b594be5a38f175.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixed (de)serialization of NamedEntityExtractor. Includes updated tests verifying these fixes when NamedEntityExtractor is used in pipelines. diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index d47ae69dd..180bee41e 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -3,12 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from haystack import ComponentError, DeserializationError +from haystack import ComponentError, DeserializationError, Pipeline from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend from haystack.utils.device import ComponentDevice -@pytest.mark.unit def test_named_entity_extractor_backend(): _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER") @@ -22,7 +21,6 @@ def test_named_entity_extractor_backend(): NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER") -@pytest.mark.unit def test_named_entity_extractor_serde(): extractor = NamedEntityExtractor( backend=NamedEntityExtractorBackend.HUGGING_FACE, @@ -42,7 +40,20 @@ def test_named_entity_extractor_serde(): _ = NamedEntityExtractor.from_dict(serde_data) -@pytest.mark.unit +# tests for NamedEntityExtractor serialization/deserialization in a pipeline +def test_named_entity_extractor_pipeline_serde(tmp_path): + extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER") + p = Pipeline() + p.add_component(instance=extractor, name="extractor") + + with open(tmp_path / "test_pipeline.yaml", "w") as f: + p.dump(f) + with open(tmp_path / "test_pipeline.yaml", "r") as f: + q = Pipeline.load(f) + + assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with NamedEntityExtractor failed." + + def test_named_entity_extractor_serde_none_device(): extractor = NamedEntityExtractor( backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None