mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 17:59:27 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			31 lines
		
	
	
		
			993 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			31 lines
		
	
	
		
			993 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging
 | |
| 
 | |
| from haystack.modeling.model.adaptive_model import AdaptiveModel
 | |
| from haystack.modeling.model.language_model import LanguageModel
 | |
| from haystack.modeling.model.prediction_head import QuestionAnsweringHead
 | |
| from haystack.modeling.utils import set_all_seeds, initialize_device_settings
 | |
| 
 | |
| 
 | |
| def test_prediction_head_load_save(tmp_path, caplog=None):
 | |
|     if caplog:
 | |
|         caplog.set_level(logging.CRITICAL)
 | |
| 
 | |
|     set_all_seeds(seed=42)
 | |
|     devices, n_gpu = initialize_device_settings(use_cuda=False)
 | |
|     lang_model = "bert-base-german-cased"
 | |
| 
 | |
|     language_model = LanguageModel.load(lang_model)
 | |
|     prediction_head = QuestionAnsweringHead()
 | |
| 
 | |
|     model = AdaptiveModel(
 | |
|         language_model=language_model,
 | |
|         prediction_heads=[prediction_head],
 | |
|         embeds_dropout_prob=0.1,
 | |
|         lm_output_types=["per_sequence"],
 | |
|         device=devices[0],
 | |
|     )
 | |
| 
 | |
|     model.save(tmp_path)
 | |
|     model_loaded = AdaptiveModel.load(tmp_path, device="cpu")
 | |
|     assert model_loaded is not None
 | 
