2021-09-09 11:54:47 +02:00
|
|
|
import logging
|
|
|
|
|
2021-09-13 18:38:14 +02:00
|
|
|
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
|
|
|
from haystack.modeling.model.language_model import LanguageModel
|
|
|
|
from haystack.modeling.model.prediction_head import QuestionAnsweringHead
|
|
|
|
from haystack.modeling.utils import set_all_seeds, initialize_device_settings
|
2021-09-09 11:54:47 +02:00
|
|
|
|
|
|
|
|
|
|
|
def test_prediction_head_load_save(tmp_path, caplog=None):
    """Round-trip an AdaptiveModel with a QA prediction head through save/load.

    Builds a BERT-based AdaptiveModel with a single QuestionAnsweringHead,
    saves it to a temporary directory, reloads it on CPU, and checks that
    loading produced a model object.
    """
    # Silence log noise when running under pytest's caplog fixture.
    if caplog:
        caplog.set_level(logging.CRITICAL)

    set_all_seeds(seed=42)
    devices, n_gpu = initialize_device_settings(use_cuda=False)

    base_model_name = "bert-base-german-cased"
    base_lm = LanguageModel.load(base_model_name)
    qa_head = QuestionAnsweringHead()

    adaptive_model = AdaptiveModel(
        language_model=base_lm,
        prediction_heads=[qa_head],
        embeds_dropout_prob=0.1,
        lm_output_types=["per_sequence"],
        device=devices[0],
    )

    # Persist to disk and reload on CPU to verify the serialization round trip.
    adaptive_model.save(tmp_path)
    reloaded_model = AdaptiveModel.load(tmp_path, device="cpu")
    assert reloaded_model is not None