haystack/test/modeling/test_model_loading.py
Grant Williams 1cf70d3dce
build: Upgrade transformers to the latest version 4.34.1 (#5994)
* Upgrade transformers to the latest version 4.34.0 so that Haystack can support the new Mistral, Nougat, and other models.

* update release notes

* updated missing lazy import

* Update .github workflows imports

* bump more versions in .github workflows

* revert import sorting

* Update to catch runtime errors to match haystack_hub changes

* add language parameter value to whisper test

* bump transformers version in linting preview workflow

* bump transformers version in linting preview workflow

* bump version to v4.34.1

* resolve mypy issue with reused variables

* install openai-whisper without dependencies

* remove audio extra, update whisper install instructions

* remove audio extra, update whisper install instructions

* keep audio extra but add version

* keep audio extra with no constraints

* remove audio extra

---------

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
2023-10-24 19:13:12 +02:00

44 lines
1.2 KiB
Python

import pytest
from haystack.modeling.model.language_model import (
get_language_model,
HFLanguageModel,
HFLanguageModelNoSegmentIds,
HFLanguageModelWithPooler,
DPREncoder,
)
@pytest.mark.integration
@pytest.mark.parametrize(
    "pretrained_model_name_or_path, lm_class",
    [
        ("google/bert_uncased_L-2_H-128_A-2", HFLanguageModel),
        ("google/electra-small-generator", HFLanguageModelWithPooler),
        ("distilbert-base-uncased", HFLanguageModelNoSegmentIds),
        ("deepset/bert-small-mm_retrieval-passage_encoder", DPREncoder),
    ],
)
def test_basic_loading(pretrained_model_name_or_path, lm_class, monkeypatch):
    """Check that each model name yields an instance of the expected class.

    For every ``(model name, class)`` pair above, ``get_language_model`` is
    called with the model name and the result must be an instance of the
    paired class.
    """

    def _noop_init(self, *args, **kwargs):
        # Stand-in constructor: accepts any arguments and does nothing.
        return None

    # Replace the expected class's ``__init__`` for the duration of this test;
    # ``monkeypatch`` restores the original afterwards.
    monkeypatch.setattr(lm_class, "__init__", _noop_init)

    loaded_model = get_language_model(pretrained_model_name_or_path)
    assert isinstance(loaded_model, lm_class)
@pytest.mark.unit
def test_basic_loading_unknown_model():
    """Loading the name ``model_that_doesnt_exist`` must raise ``RuntimeError``."""
    unknown_model_name = "model_that_doesnt_exist"
    with pytest.raises(RuntimeError):
        get_language_model(unknown_model_name)
@pytest.mark.unit
def test_basic_loading_with_empty_string():
    """Passing an empty string as the model name must raise ``ValueError``."""
    empty_model_name = ""
    with pytest.raises(ValueError):
        get_language_model(empty_model_name)
@pytest.mark.unit
def test_basic_loading_invalid_params():
    """Passing ``None`` as the model name must raise ``ValueError``."""
    missing_model_name = None
    with pytest.raises(ValueError):
        get_language_model(missing_model_name)