haystack/test/modeling/test_language.py
Sara Zan 4e45062a00
Simplify language_modeling.py and tokenization.py (#2703)
* Simplification of language_model.py and tokenization.py to remove code duplication

Co-authored-by: vblagoje <dovlex@gmail.com>
2022-07-22 16:29:30 +02:00

35 lines
1.1 KiB
Python

import importlib

import pytest

from haystack.modeling.model.language_model import get_language_model
@pytest.mark.parametrize(
    "pretrained_model_name_or_path, lm_class",
    [
        ("google/bert_uncased_L-2_H-128_A-2", "HFLanguageModel"),
        ("google/electra-small-generator", "HFLanguageModelWithPooler"),
        ("distilbert-base-uncased", "HFLanguageModelNoSegmentIds"),
        ("deepset/bert-small-mm_retrieval-passage_encoder", "DPREncoder"),
    ],
)
def test_basic_loading(pretrained_model_name_or_path, lm_class):
    """Check that ``get_language_model`` returns an instance of the expected class.

    :param pretrained_model_name_or_path: model name passed to ``get_language_model``.
    :param lm_class: name of a class in ``haystack.modeling.model.language_model``
        that the loaded object must be an instance of.
    """
    import importlib  # local import keeps this block self-contained

    lm = get_language_model(pretrained_model_name_or_path)

    # importlib.import_module is the documented replacement for calling
    # __import__ directly; it returns the named (leaf) module itself.
    module = importlib.import_module("haystack.modeling.model.language_model")
    expected_class = getattr(module, lm_class)

    assert isinstance(lm, expected_class), (
        f"{pretrained_model_name_or_path!r} loaded as {type(lm).__name__}, "
        f"expected {lm_class}"
    )
def test_basic_loading_unknown_model():
    """Loading a model name that cannot be resolved is expected to raise ``OSError``."""
    missing_model_name = "model_that_doesnt_exist"
    with pytest.raises(OSError):
        get_language_model(missing_model_name)
def test_basic_loading_with_empty_string():
    """An empty model name is expected to be rejected with ``ValueError``."""
    blank_name = ""
    with pytest.raises(ValueError):
        get_language_model(blank_name)
def test_basic_loading_invalid_params():
    """Passing ``None`` as the model name is expected to be rejected with ``ValueError``."""
    no_model_name = None
    with pytest.raises(ValueError):
        get_language_model(no_model_name)