fix: Fix NamedEntityExtractor serde (#7684)

* Fix NamedEntityExtractor serde

* Add release note

* Linting, remove unit markers
This commit is contained in:
Vladimir Blagojevic 2024-05-14 12:24:55 +02:00 committed by GitHub
parent 75cf35c743
commit 4352b1688e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 5 deletions

View File

@ -197,7 +197,7 @@ class NamedEntityExtractor:
"""
return default_to_dict(
self,
backend=self._backend.type,
backend=self._backend.type.name,
model=self._backend.model_name,
device=self._backend.device.to_dict(),
pipeline_kwargs=self._backend._pipeline_kwargs,
@ -217,6 +217,7 @@ class NamedEntityExtractor:
init_params = data["init_parameters"]
if init_params["device"] is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
return default_from_dict(cls, data)
except Exception as e:
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Fixed (de)serialization of NamedEntityExtractor. Includes updated tests verifying these fixes when NamedEntityExtractor is used in pipelines.

View File

@ -3,12 +3,11 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from haystack import ComponentError, DeserializationError
from haystack import ComponentError, DeserializationError, Pipeline
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
from haystack.utils.device import ComponentDevice
@pytest.mark.unit
def test_named_entity_extractor_backend():
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
@ -22,7 +21,6 @@ def test_named_entity_extractor_backend():
NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER")
@pytest.mark.unit
def test_named_entity_extractor_serde():
extractor = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE,
@ -42,7 +40,20 @@ def test_named_entity_extractor_serde():
_ = NamedEntityExtractor.from_dict(serde_data)
@pytest.mark.unit
# tests for NamedEntityExtractor serialization/deserialization in a pipeline
def test_named_entity_extractor_pipeline_serde(tmp_path):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
p = Pipeline()
p.add_component(instance=extractor, name="extractor")
with open(tmp_path / "test_pipeline.yaml", "w") as f:
p.dump(f)
with open(tmp_path / "test_pipeline.yaml", "r") as f:
q = Pipeline.load(f)
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with NamedEntityExtractor failed."
def test_named_entity_extractor_serde_none_device():
extractor = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None