mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-24 06:19:21 +00:00

* Use initialize_device_settings in all nodes
* Set StreamHandler level to INFO
* Add latest docstring and tutorial changes
* Work in progress
* Standardize device initialization
* Add latest docstring and tutorial changes
* Adapt device initialization in Reader's train method

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
30 lines
987 B
Python
30 lines
987 B
Python
import logging
|
|
|
|
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
|
from haystack.modeling.model.language_model import LanguageModel
|
|
from haystack.modeling.model.prediction_head import QuestionAnsweringHead
|
|
from haystack.modeling.utils import set_all_seeds, initialize_device_settings
|
|
|
|
|
|
def test_prediction_head_load_save(tmp_path, caplog=None):
    """Round-trip an AdaptiveModel with a QuestionAnsweringHead through save/load.

    Builds a model from ``bert-base-german-cased`` plus a default
    ``QuestionAnsweringHead``, saves it to ``tmp_path``, reloads it, and checks
    that the reloaded model has the same number of prediction heads and
    identical parameters/buffers as the model that was saved.

    :param tmp_path: directory the model is saved to and loaded from.
    :param caplog: optional log-capture object; when given, its level is set
        to CRITICAL to silence log output during the test.
    """
    # NOTE(review): pytest only injects fixtures for parameters without a
    # default value, so under pytest ``caplog`` stays None and this branch is
    # skipped -- confirm whether silencing logs is still wanted here.
    if caplog:
        caplog.set_level(logging.CRITICAL)

    set_all_seeds(seed=42)
    # The GPU count is not needed: the test explicitly disables CUDA.
    devices, _ = initialize_device_settings(use_cuda=False)
    device = devices[0]
    lang_model = "bert-base-german-cased"

    language_model = LanguageModel.load(lang_model)
    prediction_head = QuestionAnsweringHead()

    model = AdaptiveModel(
        language_model=language_model,
        prediction_heads=[prediction_head],
        embeds_dropout_prob=0.1,
        lm_output_types=["per_sequence"],
        device=device,
    )

    model.save(tmp_path)
    # Reload onto the same device the model was built on (instead of a
    # hard-coded "cpu") so the tensors compared below live on one device.
    model_loaded = AdaptiveModel.load(tmp_path, device=device)

    assert model_loaded is not None
    assert len(model_loaded.prediction_heads) == len(model.prediction_heads)

    # A bare "is not None" check would also pass if weights were silently
    # re-initialised on load; compare every saved tensor with its reload.
    original_state = model.state_dict()
    loaded_state = model_loaded.state_dict()
    assert original_state.keys() == loaded_state.keys()
    for name, tensor in original_state.items():
        assert tensor.equal(loaded_state[name]), f"mismatch after reload: {name}"
|