mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-27 10:49:52 +00:00
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
![]() |
import pytest
|
||
|
|
||
|
from haystack.modeling.model.language_model import get_language_model
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    "pretrained_model_name_or_path, lm_class",
    [
        ("google/bert_uncased_L-2_H-128_A-2", "HFLanguageModel"),
        ("google/electra-small-generator", "HFLanguageModelWithPooler"),
        ("distilbert-base-uncased", "HFLanguageModelNoSegmentIds"),
        ("deepset/bert-small-mm_retrieval-passage_encoder", "DPREncoder"),
    ],
)
def test_basic_loading(pretrained_model_name_or_path, lm_class):
    """Check that ``get_language_model`` returns the expected wrapper class.

    :param pretrained_model_name_or_path: Model name passed to ``get_language_model``.
    :param lm_class: Name of the class in ``haystack.modeling.model.language_model``
        that the loaded model is expected to be. It is resolved by name at test time,
        so the parametrization stays a plain list of strings.
    """
    import importlib

    lm = get_language_model(pretrained_model_name_or_path)

    module = importlib.import_module("haystack.modeling.model.language_model")
    expected_class = getattr(module, lm_class)

    # Compare the exact type instead of using isinstance(). If one of the expected
    # classes derives from another, isinstance() would also accept the wrong
    # (sub)class, and the test could not tell the per-model mapping apart.
    assert type(lm) is expected_class, (
        f"{pretrained_model_name_or_path!r} was loaded as {type(lm).__name__}, "
        f"expected {lm_class}"
    )
|
||
|
|
||
|
|
||
|
def test_basic_loading_unknown_model():
    """``get_language_model`` must raise ``OSError`` for a model name that does not exist."""
    pytest.raises(OSError, get_language_model, "model_that_doesnt_exist")
|
||
|
|
||
|
|
||
|
def test_basic_loading_with_empty_string():
    """``get_language_model`` must raise ``ValueError`` when given an empty model name."""
    pytest.raises(ValueError, get_language_model, "")
|
||
|
|
||
|
|
||
|
def test_basic_loading_invalid_params():
    """``get_language_model`` must raise ``ValueError`` when the model name is ``None``."""
    pytest.raises(ValueError, get_language_model, None)
|