fix: Fix NamedEntityExtractor serde (#7684)

* Fix NamedEntityExtractor serde

* Add release note

* Linting, remove unit markers
This commit is contained in:
Vladimir Blagojevic 2024-05-14 12:24:55 +02:00 committed by GitHub
parent 75cf35c743
commit 4352b1688e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 5 deletions

View File

@ -197,7 +197,7 @@ class NamedEntityExtractor:
""" """
return default_to_dict( return default_to_dict(
self, self,
backend=self._backend.type, backend=self._backend.type.name,
model=self._backend.model_name, model=self._backend.model_name,
device=self._backend.device.to_dict(), device=self._backend.device.to_dict(),
pipeline_kwargs=self._backend._pipeline_kwargs, pipeline_kwargs=self._backend._pipeline_kwargs,
@ -217,6 +217,7 @@ class NamedEntityExtractor:
init_params = data["init_parameters"] init_params = data["init_parameters"]
if init_params["device"] is not None: if init_params["device"] is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"]) init_params["device"] = ComponentDevice.from_dict(init_params["device"])
init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
return default_from_dict(cls, data) return default_from_dict(cls, data)
except Exception as e: except Exception as e:
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Fixed (de)serialization of NamedEntityExtractor. Includes updated tests verifying these fixes when NamedEntityExtractor is used in pipelines.

View File

@ -3,12 +3,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest import pytest
from haystack import ComponentError, DeserializationError from haystack import ComponentError, DeserializationError, Pipeline
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
from haystack.utils.device import ComponentDevice from haystack.utils.device import ComponentDevice
@pytest.mark.unit
def test_named_entity_extractor_backend(): def test_named_entity_extractor_backend():
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER") _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
@ -22,7 +21,6 @@ def test_named_entity_extractor_backend():
NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER") NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER")
@pytest.mark.unit
def test_named_entity_extractor_serde(): def test_named_entity_extractor_serde():
extractor = NamedEntityExtractor( extractor = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE, backend=NamedEntityExtractorBackend.HUGGING_FACE,
@ -42,7 +40,20 @@ def test_named_entity_extractor_serde():
_ = NamedEntityExtractor.from_dict(serde_data) _ = NamedEntityExtractor.from_dict(serde_data)
@pytest.mark.unit # tests for NamedEntityExtractor serialization/deserialization in a pipeline
def test_named_entity_extractor_pipeline_serde(tmp_path):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
p = Pipeline()
p.add_component(instance=extractor, name="extractor")
with open(tmp_path / "test_pipeline.yaml", "w") as f:
p.dump(f)
with open(tmp_path / "test_pipeline.yaml", "r") as f:
q = Pipeline.load(f)
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with NamedEntityExtractor failed."
def test_named_entity_extractor_serde_none_device(): def test_named_entity_extractor_serde_none_device():
extractor = NamedEntityExtractor( extractor = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None