mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-28 11:19:58 +00:00

* Add type annotations in QuestionAnsweringHead
* Fix test by increasing max_seq_len
* Add SampleBasket type annotation
* Remove prediction head param from adaptive model init
* Add type ignore for AdaptiveModel init
* Fix and rename tests
* Adjust folder structure

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
30 lines
982 B
Python
30 lines
982 B
Python
import logging

import torch

from haystack.modeling.model.adaptive_model import AdaptiveModel
from haystack.modeling.model.language_model import LanguageModel
from haystack.modeling.model.prediction_head import QuestionAnsweringHead
from haystack.modeling.utils import set_all_seeds, initialize_device_settings


def test_prediction_head_load_save(tmp_path, caplog=None):
    """Round-trip an AdaptiveModel with a QuestionAnsweringHead through save/load.

    Builds a CPU AdaptiveModel from the ``bert-base-german-cased`` language
    model plus a freshly initialised QuestionAnsweringHead. It then saves the
    model to ``tmp_path``, loads it back, and checks that the prediction head
    type and its weights survived the round trip.

    :param tmp_path: pytest-provided temporary directory used as the save dir.
    :param caplog: optional pytest ``caplog`` fixture; when given, logging is
        silenced to CRITICAL. NOTE(review): pytest does not inject fixtures
        into parameters that have a default value, so under a plain pytest
        run this stays ``None``. The default is kept so the test can also be
        called directly.
    """
    if caplog:
        caplog.set_level(logging.CRITICAL)

    # Fixed seed so the randomly initialised head weights are reproducible.
    set_all_seeds(seed=42)
    device, _n_gpu = initialize_device_settings(use_cuda=False)

    language_model = LanguageModel.load("bert-base-german-cased")
    prediction_head = QuestionAnsweringHead()

    model = AdaptiveModel(
        language_model=language_model,
        prediction_heads=[prediction_head],
        embeds_dropout_prob=0.1,
        # A QA head consumes per-token LM output, not the pooled sequence
        # vector ("per_sequence"). Use the output type the head itself
        # declares (presumably "per_token_squad" - confirm in prediction_head.py).
        lm_output_types=[prediction_head.ph_output_type],
        device=device,
    )

    model.save(tmp_path)
    # Load onto the same device the model was built on, rather than a
    # separately hard-coded one.
    model_loaded = AdaptiveModel.load(tmp_path, device=device)

    assert model_loaded is not None
    assert len(model_loaded.prediction_heads) == 1
    loaded_head = model_loaded.prediction_heads[0]
    assert isinstance(loaded_head, QuestionAnsweringHead)

    # The saved head weights must be restored exactly, not just "something loaded".
    original_state = model.prediction_heads[0].state_dict()
    loaded_state = loaded_head.state_dict()
    assert original_state.keys() == loaded_state.keys()
    for param_name, tensor in original_state.items():
        assert torch.equal(tensor.cpu(), loaded_state[param_name].cpu()), param_name