mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-27 02:40:41 +00:00
30 lines
983 B
Python
30 lines
983 B
Python
![]() |
import logging
|
||
|
|
||
|
from haystack.basics.modeling.adaptive_model import AdaptiveModel
|
||
|
from haystack.basics.modeling.language_model import LanguageModel
|
||
|
from haystack.basics.modeling.prediction_head import QuestionAnsweringHead
|
||
|
from haystack.basics.utils import set_all_seeds, initialize_device_settings
|
||
|
|
||
|
|
||
|
def test_prediction_head_load_save(tmp_path, caplog=None):
|
||
|
if caplog:
|
||
|
caplog.set_level(logging.CRITICAL)
|
||
|
|
||
|
set_all_seeds(seed=42)
|
||
|
device, n_gpu = initialize_device_settings(use_cuda=False)
|
||
|
lang_model = "bert-base-german-cased"
|
||
|
|
||
|
language_model = LanguageModel.load(lang_model)
|
||
|
prediction_head = QuestionAnsweringHead()
|
||
|
|
||
|
model = AdaptiveModel(
|
||
|
language_model=language_model,
|
||
|
prediction_heads=[prediction_head],
|
||
|
embeds_dropout_prob=0.1,
|
||
|
lm_output_types=["per_sequence"],
|
||
|
device=device)
|
||
|
|
||
|
model.save(tmp_path)
|
||
|
model_loaded = AdaptiveModel.load(tmp_path, device='cpu')
|
||
|
assert model_loaded is not None
|