haystack/test/modeling/test_model_loading.py
Sebastian 707f1c3546
Add modeling to unit tests so we can get coverage for that (#4809)
* Add modeling to unit tests so we can get coverage for that

* fix unit tests

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
2023-05-08 19:05:21 +02:00

44 lines
1.2 KiB
Python

import pytest
from haystack.modeling.model.language_model import (
get_language_model,
HFLanguageModel,
HFLanguageModelNoSegmentIds,
HFLanguageModelWithPooler,
DPREncoder,
)
@pytest.mark.integration
@pytest.mark.parametrize(
    "pretrained_model_name_or_path, lm_class",
    [
        ("google/bert_uncased_L-2_H-128_A-2", HFLanguageModel),
        ("google/electra-small-generator", HFLanguageModelWithPooler),
        ("distilbert-base-uncased", HFLanguageModelNoSegmentIds),
        ("deepset/bert-small-mm_retrieval-passage_encoder", DPREncoder),
    ],
)
def test_basic_loading(pretrained_model_name_or_path, lm_class, monkeypatch):
    """Check that ``get_language_model`` returns an instance of the expected wrapper class.

    The wrapper's ``__init__`` is swapped for a no-op, so the real constructor
    never runs; only the class that ``get_language_model`` picks is asserted.
    NOTE(review): marked ``integration`` presumably because resolving the model
    name may still need the Hugging Face Hub -- confirm.
    """

    def _noop_init(self, *args, **kwargs):
        # Stand-in constructor: accepts any arguments and leaves the instance bare.
        return None

    monkeypatch.setattr(lm_class, "__init__", _noop_init)
    loaded_model = get_language_model(pretrained_model_name_or_path)
    assert isinstance(loaded_model, lm_class)
@pytest.mark.unit
def test_basic_loading_unknown_model():
    """A model name that cannot be resolved makes ``get_language_model`` raise ``OSError``."""
    unknown_name = "model_that_doesnt_exist"
    with pytest.raises(OSError):
        get_language_model(unknown_name)
@pytest.mark.unit
def test_basic_loading_with_empty_string():
    """An empty model name is rejected by ``get_language_model`` with ``ValueError``."""
    # Callable form of pytest.raises: invokes get_language_model("") and
    # fails the test unless ValueError is raised.
    pytest.raises(ValueError, get_language_model, "")
@pytest.mark.unit
def test_basic_loading_invalid_params():
    """Passing ``None`` instead of a model name makes ``get_language_model`` raise ``ValueError``."""
    missing_name = None
    with pytest.raises(ValueError):
        get_language_model(missing_name)