haystack/test/modeling/test_model_loading.py
ZanSara b60d9a2cbf
test: move several modeling tests in e2e/ (#4308)
* no dpr test seems worth mocking

* move distillation tests

* pylint

* mypy

* pylint

* move feature_extraction tests as well

* move feature_extraction tests as well

* merge feature extractor suites

* get_language_model tests and adaptive model tests

* duplicate test

* moving fixtures

* mypy

* mypy-again

* trigger

* un-mock integration test

* review feedback

* feedback

* pylint
2023-04-28 17:08:41 +02:00

44 lines
1.2 KiB
Python

import pytest
from haystack.modeling.model.language_model import (
get_language_model,
HFLanguageModel,
HFLanguageModelNoSegmentIds,
HFLanguageModelWithPooler,
DPREncoder,
)
@pytest.mark.unit
@pytest.mark.parametrize(
    "pretrained_model_name_or_path, lm_class",
    (
        ("google/bert_uncased_L-2_H-128_A-2", HFLanguageModel),
        ("google/electra-small-generator", HFLanguageModelWithPooler),
        ("distilbert-base-uncased", HFLanguageModelNoSegmentIds),
        ("deepset/bert-small-mm_retrieval-passage_encoder", DPREncoder),
    ),
)
def test_basic_loading(pretrained_model_name_or_path, lm_class, monkeypatch):
    """Check that ``get_language_model`` returns the expected wrapper class.

    The constructor of ``lm_class`` is swapped for a no-op while the test runs.
    Only the selection of the wrapper class is exercised, not the wrapper's real
    initialisation logic.

    NOTE(review): confirm whether model-type resolution inside
    ``get_language_model`` still needs network access despite the ``unit``
    marker.
    """

    def _noop_init(self, *args, **kwargs):
        # Stand-in for ``lm_class.__init__``; an ``__init__`` must return None.
        return None

    # ``monkeypatch`` undoes this attribute replacement when the test ends.
    monkeypatch.setattr(lm_class, "__init__", _noop_init)

    language_model = get_language_model(pretrained_model_name_or_path)

    assert isinstance(language_model, lm_class)


@pytest.mark.unit
def test_basic_loading_unknown_model():
    """An identifier that matches no known model must raise ``OSError``."""
    # NOTE(review): the OSError presumably originates in the underlying
    # model/config lookup rather than in Haystack itself — confirm.
    missing_model = "model_that_doesnt_exist"
    with pytest.raises(OSError):
        get_language_model(missing_model)


@pytest.mark.unit
def test_basic_loading_with_empty_string():
    """An empty model name must be rejected with ``ValueError``."""
    empty_name = ""
    with pytest.raises(ValueError):
        get_language_model(empty_name)


@pytest.mark.unit
def test_basic_loading_invalid_params():
    """Passing ``None`` instead of a model name must raise ``ValueError``."""
    no_name = None
    with pytest.raises(ValueError):
        get_language_model(no_name)