2024-01-09 17:56:20 +01:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
from haystack import ComponentError, DeserializationError
|
|
|
|
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
|
2024-01-17 10:41:34 +01:00
|
|
|
from haystack.utils.device import ComponentDevice
|
2024-01-09 17:56:20 +01:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_named_entity_extractor_backend():
    """Check which ``backend`` values the extractor constructor accepts.

    Each supported backend is passed once as a ``NamedEntityExtractorBackend``
    member and once as a plain string; an unrecognised string must raise
    ``ComponentError`` with a message matching ``"Invalid backend"``.
    """
    accepted_configurations = (
        (NamedEntityExtractorBackend.HUGGING_FACE, "dslim/bert-base-NER"),
        ("hugging_face", "dslim/bert-base-NER"),
        (NamedEntityExtractorBackend.SPACY, "en_core_web_sm"),
        ("spacy", "en_core_web_sm"),
    )
    # Construction alone is the check here: none of these may raise.
    for backend, model in accepted_configurations:
        NamedEntityExtractor(backend=backend, model=model)

    with pytest.raises(ComponentError, match=r"Invalid backend"):
        NamedEntityExtractor(backend="random_backend", model="dslim/bert-base-NER")
|
2024-01-09 17:56:20 +01:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_named_entity_extractor_serde():
    """Check that an extractor survives a ``to_dict``/``from_dict`` round trip.

    The rebuilt extractor must use the same backend type, model name and
    device as the original. Afterwards, the ``backend`` entry is removed from
    the serialized init parameters and ``from_dict`` is expected to raise
    ``DeserializationError`` with a message matching ``"Couldn't deserialize"``.
    """
    extractor = NamedEntityExtractor(
        backend=NamedEntityExtractorBackend.HUGGING_FACE,
        model="dslim/bert-base-NER",
        device=ComponentDevice.from_str("cuda:1"),
    )

    serde_data = extractor.to_dict()
    new_extractor = NamedEntityExtractor.from_dict(serde_data)

    # Exact type identity is intended, so compare with `is` rather than `==`.
    assert type(new_extractor._backend) is type(extractor._backend)
    assert new_extractor._backend.model_name == extractor._backend.model_name
    assert new_extractor._backend.device == extractor._backend.device

    # Corrupt the payload outside the `raises` context so that only the
    # deserialization call itself is covered by the expectation.
    serde_data["init_parameters"].pop("backend")
    with pytest.raises(DeserializationError, match=r"Couldn't deserialize"):
        NamedEntityExtractor.from_dict(serde_data)
|