mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 01:39:45 +00:00 
			
		
		
		
	fix: Ensure eval mode for farm and transformer models for predictions (#3791)
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
		
							parent
							
								
									97d5db3b9c
								
							
						
					
					
						commit
						1777b22fcb
					
				| @ -290,6 +290,7 @@ class Trainer: | ||||
|                 ) | ||||
|                 self.test_result = evaluator_test.eval(self.model) | ||||
|                 evaluator_test.log_results(self.test_result, "Test", self.global_step) | ||||
|         self.model.eval() | ||||
|         return self.model | ||||
| 
 | ||||
|     def compute_loss(self, batch: dict, step: int) -> torch.Tensor: | ||||
|  | ||||
| @ -431,6 +431,36 @@ def test_no_answer_reader_skips_empty_documents(no_answer_reader): | ||||
|     assert predictions["answers"][1][1].answer == "Carla"  # answer given for 2nd query as usual | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.integration | ||||
| def test_reader_training_returns_eval(tmp_path, samples_path): | ||||
|     max_seq_len = 16 | ||||
|     max_query_length = 8 | ||||
|     reader = FARMReader( | ||||
|         model_name_or_path="deepset/tinyroberta-squad2", | ||||
|         use_gpu=False, | ||||
|         num_processes=0, | ||||
|         max_seq_len=max_seq_len, | ||||
|         doc_stride=2, | ||||
|         max_query_length=max_query_length, | ||||
|     ) | ||||
| 
 | ||||
|     save_dir = f"{tmp_path}/test_dpr_training" | ||||
|     reader.train( | ||||
|         data_dir=str(samples_path / "squad"), | ||||
|         train_filename="tiny.json", | ||||
|         dev_filename="tiny.json", | ||||
|         n_epochs=1, | ||||
|         batch_size=1, | ||||
|         grad_acc_steps=1, | ||||
|         evaluate_every=0, | ||||
|         save_dir=save_dir, | ||||
|         max_seq_len=max_seq_len, | ||||
|         max_query_length=max_query_length, | ||||
|     ) | ||||
|     assert reader.inferencer.model.training is False | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.integration | ||||
| def test_reader_training(tmp_path, samples_path): | ||||
|     max_seq_len = 16 | ||||
|     max_query_length = 8 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian
						Sebastian