mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 01:39:45 +00:00 
			
		
		
		
	fix: Ensure eval mode for farm and transformer models for predictions (#3791)
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
		
							parent
							
								
									97d5db3b9c
								
							
						
					
					
						commit
						1777b22fcb
					
				| @ -290,6 +290,7 @@ class Trainer: | |||||||
|                 ) |                 ) | ||||||
|                 self.test_result = evaluator_test.eval(self.model) |                 self.test_result = evaluator_test.eval(self.model) | ||||||
|                 evaluator_test.log_results(self.test_result, "Test", self.global_step) |                 evaluator_test.log_results(self.test_result, "Test", self.global_step) | ||||||
|  |         self.model.eval() | ||||||
|         return self.model |         return self.model | ||||||
| 
 | 
 | ||||||
|     def compute_loss(self, batch: dict, step: int) -> torch.Tensor: |     def compute_loss(self, batch: dict, step: int) -> torch.Tensor: | ||||||
|  | |||||||
| @ -431,6 +431,36 @@ def test_no_answer_reader_skips_empty_documents(no_answer_reader): | |||||||
|     assert predictions["answers"][1][1].answer == "Carla"  # answer given for 2nd query as usual |     assert predictions["answers"][1][1].answer == "Carla"  # answer given for 2nd query as usual | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @pytest.mark.integration | ||||||
|  | def test_reader_training_returns_eval(tmp_path, samples_path): | ||||||
|  |     max_seq_len = 16 | ||||||
|  |     max_query_length = 8 | ||||||
|  |     reader = FARMReader( | ||||||
|  |         model_name_or_path="deepset/tinyroberta-squad2", | ||||||
|  |         use_gpu=False, | ||||||
|  |         num_processes=0, | ||||||
|  |         max_seq_len=max_seq_len, | ||||||
|  |         doc_stride=2, | ||||||
|  |         max_query_length=max_query_length, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     save_dir = f"{tmp_path}/test_dpr_training" | ||||||
|  |     reader.train( | ||||||
|  |         data_dir=str(samples_path / "squad"), | ||||||
|  |         train_filename="tiny.json", | ||||||
|  |         dev_filename="tiny.json", | ||||||
|  |         n_epochs=1, | ||||||
|  |         batch_size=1, | ||||||
|  |         grad_acc_steps=1, | ||||||
|  |         evaluate_every=0, | ||||||
|  |         save_dir=save_dir, | ||||||
|  |         max_seq_len=max_seq_len, | ||||||
|  |         max_query_length=max_query_length, | ||||||
|  |     ) | ||||||
|  |     assert reader.inferencer.model.training is False | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.integration | ||||||
| def test_reader_training(tmp_path, samples_path): | def test_reader_training(tmp_path, samples_path): | ||||||
|     max_seq_len = 16 |     max_seq_len = 16 | ||||||
|     max_query_length = 8 |     max_query_length = 8 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian
						Sebastian