mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-02-06 23:12:43 +00:00
fix: Fix deserialization of NamedEntityExtractor when pipeline_kwargs has value of None (#10292)
* Fix sede bug in NamedEntityExtractor * Add reno * Use correct backticks
This commit is contained in:
parent
335d2f6d8b
commit
42ca057cfd
@ -249,8 +249,8 @@ class NamedEntityExtractor:
|
||||
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||
init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
|
||||
|
||||
hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {})
|
||||
deserialize_hf_model_kwargs(hf_pipeline_kwargs)
|
||||
hf_pipeline_kwargs = init_params.get("pipeline_kwargs")
|
||||
deserialize_hf_model_kwargs(hf_pipeline_kwargs or {})
|
||||
return default_from_dict(cls, data)
|
||||
except Exception as e:
|
||||
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Fixes deserializing an instance of ``NamedEntityExtractor`` when ``pipeline_kwargs`` is stored in the deserialization dict with the value of ``None``.
|
||||
@ -28,6 +28,42 @@ def test_named_entity_extractor_backend():
|
||||
NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER")
|
||||
|
||||
|
||||
def test_named_entity_extractor_to_dict():
|
||||
extractor = NamedEntityExtractor(
|
||||
backend=NamedEntityExtractorBackend.HUGGING_FACE,
|
||||
model="dslim/bert-base-NER",
|
||||
device=ComponentDevice.from_str("cuda:1"),
|
||||
)
|
||||
|
||||
serde_data = extractor.to_dict()
|
||||
assert serde_data == {
|
||||
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
|
||||
"init_parameters": {
|
||||
"backend": "HUGGING_FACE",
|
||||
"model": "dslim/bert-base-NER",
|
||||
"device": {"type": "single", "device": "cuda:1"},
|
||||
"pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "cuda:1", "task": "ner"},
|
||||
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_named_entity_extractor_from_dict():
|
||||
data = {
|
||||
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
|
||||
"init_parameters": {
|
||||
"backend": "HUGGING_FACE",
|
||||
"model": "dslim/bert-base-NER",
|
||||
"device": None,
|
||||
"pipeline_kwargs": None,
|
||||
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
|
||||
},
|
||||
}
|
||||
extractor = NamedEntityExtractor.from_dict(data)
|
||||
|
||||
assert extractor._backend.model_name == "dslim/bert-base-NER"
|
||||
|
||||
|
||||
def test_named_entity_extractor_serde():
|
||||
extractor = NamedEntityExtractor(
|
||||
backend=NamedEntityExtractorBackend.HUGGING_FACE,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user