mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-07 21:33:39 +00:00
fix: Fix NamedEntityExtractor serde (#7684)
* Fix NamedEntityExtractor serde * Add release note * Linting, remove unit markers
This commit is contained in:
parent
75cf35c743
commit
4352b1688e
@ -197,7 +197,7 @@ class NamedEntityExtractor:
|
||||
"""
|
||||
return default_to_dict(
|
||||
self,
|
||||
backend=self._backend.type,
|
||||
backend=self._backend.type.name,
|
||||
model=self._backend.model_name,
|
||||
device=self._backend.device.to_dict(),
|
||||
pipeline_kwargs=self._backend._pipeline_kwargs,
|
||||
@ -217,6 +217,7 @@ class NamedEntityExtractor:
|
||||
init_params = data["init_parameters"]
|
||||
if init_params["device"] is not None:
|
||||
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||
init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
|
||||
return default_from_dict(cls, data)
|
||||
except Exception as e:
|
||||
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Fixed (de)serialization of NamedEntityExtractor. Includes updated tests verifying these fixes when NamedEntityExtractor is used in pipelines.
|
||||
@ -3,12 +3,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
|
||||
from haystack import ComponentError, DeserializationError
|
||||
from haystack import ComponentError, DeserializationError, Pipeline
|
||||
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
|
||||
from haystack.utils.device import ComponentDevice
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_named_entity_extractor_backend():
|
||||
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
|
||||
|
||||
@ -22,7 +21,6 @@ def test_named_entity_extractor_backend():
|
||||
NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_named_entity_extractor_serde():
|
||||
extractor = NamedEntityExtractor(
|
||||
backend=NamedEntityExtractorBackend.HUGGING_FACE,
|
||||
@ -42,7 +40,20 @@ def test_named_entity_extractor_serde():
|
||||
_ = NamedEntityExtractor.from_dict(serde_data)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
# tests for NamedEntityExtractor serialization/deserialization in a pipeline
|
||||
def test_named_entity_extractor_pipeline_serde(tmp_path):
|
||||
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
|
||||
p = Pipeline()
|
||||
p.add_component(instance=extractor, name="extractor")
|
||||
|
||||
with open(tmp_path / "test_pipeline.yaml", "w") as f:
|
||||
p.dump(f)
|
||||
with open(tmp_path / "test_pipeline.yaml", "r") as f:
|
||||
q = Pipeline.load(f)
|
||||
|
||||
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with NamedEntityExtractor failed."
|
||||
|
||||
|
||||
def test_named_entity_extractor_serde_none_device():
|
||||
extractor = NamedEntityExtractor(
|
||||
backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user