mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-27 02:40:41 +00:00

* Upgrade transformers to the latest version 4.34.0 so that Haystack can support the new Mistral, Nougat, and other models. * update release notes * updated missing lazy import * Update .github workflows imports * bump more versions in .github workflows * revert import sorting * Update to catch runtime errors to match haystack_hub changes * add language parameter value to whisper test * bump transformers version in linting preview workflow * bump transformers version in linting preview workflow * bump version to v4.34.1 * resolve mypy issue with reused variables * install openai-whisper without dependencies * remove audio extra, update whisper install instructions * remove audio extra, update whisper install instructions * keep audio extra but add version * keep audio extra with no constraints * remove audio extra --------- Co-authored-by: Julian Risch <julian.risch@deepset.ai>
44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
import pytest
|
|
|
|
from haystack.modeling.model.language_model import (
|
|
get_language_model,
|
|
HFLanguageModel,
|
|
HFLanguageModelNoSegmentIds,
|
|
HFLanguageModelWithPooler,
|
|
DPREncoder,
|
|
)
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize(
    "pretrained_model_name_or_path, lm_class",
    [
        ("google/bert_uncased_L-2_H-128_A-2", HFLanguageModel),
        ("google/electra-small-generator", HFLanguageModelWithPooler),
        ("distilbert-base-uncased", HFLanguageModelNoSegmentIds),
        ("deepset/bert-small-mm_retrieval-passage_encoder", DPREncoder),
    ],
)
def test_basic_loading(pretrained_model_name_or_path, lm_class, monkeypatch):
    """Check that get_language_model returns an instance of the expected wrapper class.

    Each parametrized case pairs a model identifier with the Haystack wrapper
    class the factory is expected to produce for it. The wrapper's ``__init__``
    is swapped for a no-op, so only the class chosen by get_language_model is
    verified, not the wrapper's own initialization.
    """

    def _skip_init(self, *args, **kwargs):
        # Stand-in initializer: accepts any arguments and does nothing.
        return None

    monkeypatch.setattr(lm_class, "__init__", _skip_init)

    resolved = get_language_model(pretrained_model_name_or_path)
    assert isinstance(resolved, lm_class)
|
|
|
|
|
|
@pytest.mark.unit
def test_basic_loading_unknown_model():
    """A model identifier that cannot be resolved must make get_language_model raise RuntimeError."""
    bogus_name = "model_that_doesnt_exist"
    with pytest.raises(RuntimeError):
        get_language_model(bogus_name)
|
|
|
|
|
|
@pytest.mark.unit
def test_basic_loading_with_empty_string():
    """An empty model identifier must be rejected by get_language_model with ValueError."""
    empty_name = ""
    with pytest.raises(ValueError):
        get_language_model(empty_name)
|
|
|
|
|
|
@pytest.mark.unit
def test_basic_loading_invalid_params():
    """Passing None as the model identifier must be rejected by get_language_model with ValueError."""
    missing_name = None
    with pytest.raises(ValueError):
        get_language_model(missing_name)
|