mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-10 06:43:58 +00:00
fix: Fix NamedEntityExtractor serde (#7684)
* Fix NamedEntityExtractor serde * Add release note * Linting, remove unit markers
This commit is contained in:
parent
75cf35c743
commit
4352b1688e
@ -197,7 +197,7 @@ class NamedEntityExtractor:
|
|||||||
"""
|
"""
|
||||||
return default_to_dict(
|
return default_to_dict(
|
||||||
self,
|
self,
|
||||||
backend=self._backend.type,
|
backend=self._backend.type.name,
|
||||||
model=self._backend.model_name,
|
model=self._backend.model_name,
|
||||||
device=self._backend.device.to_dict(),
|
device=self._backend.device.to_dict(),
|
||||||
pipeline_kwargs=self._backend._pipeline_kwargs,
|
pipeline_kwargs=self._backend._pipeline_kwargs,
|
||||||
@ -217,6 +217,7 @@ class NamedEntityExtractor:
|
|||||||
init_params = data["init_parameters"]
|
init_params = data["init_parameters"]
|
||||||
if init_params["device"] is not None:
|
if init_params["device"] is not None:
|
||||||
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||||
|
init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
|
||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
|
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
fixes:
|
||||||
|
- |
|
||||||
|
Fixed (de)serialization of NamedEntityExtractor. Includes updated tests verifying these fixes when NamedEntityExtractor is used in pipelines.
|
||||||
@ -3,12 +3,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack import ComponentError, DeserializationError
|
from haystack import ComponentError, DeserializationError, Pipeline
|
||||||
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
|
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
|
||||||
from haystack.utils.device import ComponentDevice
|
from haystack.utils.device import ComponentDevice
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
def test_named_entity_extractor_backend():
|
def test_named_entity_extractor_backend():
|
||||||
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
|
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
|
||||||
|
|
||||||
@ -22,7 +21,6 @@ def test_named_entity_extractor_backend():
|
|||||||
NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER")
|
NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
def test_named_entity_extractor_serde():
|
def test_named_entity_extractor_serde():
|
||||||
extractor = NamedEntityExtractor(
|
extractor = NamedEntityExtractor(
|
||||||
backend=NamedEntityExtractorBackend.HUGGING_FACE,
|
backend=NamedEntityExtractorBackend.HUGGING_FACE,
|
||||||
@ -42,7 +40,20 @@ def test_named_entity_extractor_serde():
|
|||||||
_ = NamedEntityExtractor.from_dict(serde_data)
|
_ = NamedEntityExtractor.from_dict(serde_data)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
# tests for NamedEntityExtractor serialization/deserialization in a pipeline
|
||||||
|
def test_named_entity_extractor_pipeline_serde(tmp_path):
|
||||||
|
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
|
||||||
|
p = Pipeline()
|
||||||
|
p.add_component(instance=extractor, name="extractor")
|
||||||
|
|
||||||
|
with open(tmp_path / "test_pipeline.yaml", "w") as f:
|
||||||
|
p.dump(f)
|
||||||
|
with open(tmp_path / "test_pipeline.yaml", "r") as f:
|
||||||
|
q = Pipeline.load(f)
|
||||||
|
|
||||||
|
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with NamedEntityExtractor failed."
|
||||||
|
|
||||||
|
|
||||||
def test_named_entity_extractor_serde_none_device():
|
def test_named_entity_extractor_serde_none_device():
|
||||||
extractor = NamedEntityExtractor(
|
extractor = NamedEntityExtractor(
|
||||||
backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None
|
backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device=None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user