| 
									
										
										
										
											2021-09-09 11:54:47 +02:00
										 |  |  | import logging | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-13 18:38:14 +02:00
										 |  |  | from haystack.modeling.model.adaptive_model import AdaptiveModel | 
					
						
							| 
									
										
										
										
											2022-07-22 16:29:30 +02:00
										 |  |  | from haystack.modeling.model.language_model import get_language_model | 
					
						
							| 
									
										
										
										
											2021-09-13 18:38:14 +02:00
										 |  |  | from haystack.modeling.model.prediction_head import QuestionAnsweringHead | 
					
						
							|  |  |  | from haystack.modeling.utils import set_all_seeds, initialize_device_settings | 
					
						
							| 
									
										
										
										
											2021-09-09 11:54:47 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_prediction_head_load_save(tmp_path, caplog=None): | 
					
						
							|  |  |  |     if caplog: | 
					
						
							|  |  |  |         caplog.set_level(logging.CRITICAL) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     set_all_seeds(seed=42) | 
					
						
							| 
									
										
										
										
											2021-11-09 12:44:20 +01:00
										 |  |  |     devices, n_gpu = initialize_device_settings(use_cuda=False) | 
					
						
							| 
									
										
										
										
											2021-09-09 11:54:47 +02:00
										 |  |  |     lang_model = "bert-base-german-cased" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-22 16:29:30 +02:00
										 |  |  |     language_model = get_language_model(lang_model) | 
					
						
							| 
									
										
										
										
											2021-09-09 11:54:47 +02:00
										 |  |  |     prediction_head = QuestionAnsweringHead() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     model = AdaptiveModel( | 
					
						
							|  |  |  |         language_model=language_model, | 
					
						
							|  |  |  |         prediction_heads=[prediction_head], | 
					
						
							|  |  |  |         embeds_dropout_prob=0.1, | 
					
						
							|  |  |  |         lm_output_types=["per_sequence"], | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |         device=devices[0], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-09-09 11:54:47 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model.save(tmp_path) | 
					
						
							| 
									
										
										
										
											2022-02-03 13:43:18 +01:00
										 |  |  |     model_loaded = AdaptiveModel.load(tmp_path, device="cpu") | 
					
						
							| 
									
										
										
										
											2021-09-09 11:54:47 +02:00
										 |  |  |     assert model_loaded is not None |