From 42ca057cfdfcf4d84e4548b7dde4f60d8d3e126e Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:53:48 +0100 Subject: [PATCH] fix: Fix deserialization of `NamedEntityExtractor` when `pipeline_kwargs` has value of `None` (#10292) * Fix serde bug in NamedEntityExtractor * Add reno * Use correct backticks --- .../extractors/named_entity_extractor.py | 4 +-- ...med-entity-extractor-84dcdde666b86cdf.yaml | 4 +++ .../extractors/test_named_entity_extractor.py | 36 +++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 releasenotes/notes/fix-sede-bug-named-entity-extractor-84dcdde666b86cdf.yaml diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index 341605c0e..dee67e749 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -249,8 +249,8 @@ class NamedEntityExtractor: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]] - hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {}) - deserialize_hf_model_kwargs(hf_pipeline_kwargs) + hf_pipeline_kwargs = init_params.get("pipeline_kwargs") + deserialize_hf_model_kwargs(hf_pipeline_kwargs or {}) return default_from_dict(cls, data) except Exception as e: raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e diff --git a/releasenotes/notes/fix-sede-bug-named-entity-extractor-84dcdde666b86cdf.yaml b/releasenotes/notes/fix-sede-bug-named-entity-extractor-84dcdde666b86cdf.yaml new file mode 100644 index 000000000..f0b0459b6 --- /dev/null +++ b/releasenotes/notes/fix-sede-bug-named-entity-extractor-84dcdde666b86cdf.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixes deserializing an instance of ``NamedEntityExtractor`` when ``pipeline_kwargs`` is stored in 
the deserialization dict with the value of ``None``. diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index ad8b44413..2cb815c2d 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -28,6 +28,42 @@ def test_named_entity_extractor_backend(): NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER") +def test_named_entity_extractor_to_dict(): + extractor = NamedEntityExtractor( + backend=NamedEntityExtractorBackend.HUGGING_FACE, + model="dslim/bert-base-NER", + device=ComponentDevice.from_str("cuda:1"), + ) + + serde_data = extractor.to_dict() + assert serde_data == { + "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor", + "init_parameters": { + "backend": "HUGGING_FACE", + "model": "dslim/bert-base-NER", + "device": {"type": "single", "device": "cuda:1"}, + "pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "cuda:1", "task": "ner"}, + "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False}, + }, + } + + +def test_named_entity_extractor_from_dict(): + data = { + "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor", + "init_parameters": { + "backend": "HUGGING_FACE", + "model": "dslim/bert-base-NER", + "device": None, + "pipeline_kwargs": None, + "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False}, + }, + } + extractor = NamedEntityExtractor.from_dict(data) + + assert extractor._backend.model_name == "dslim/bert-base-NER" + + def test_named_entity_extractor_serde(): extractor = NamedEntityExtractor( backend=NamedEntityExtractorBackend.HUGGING_FACE,