mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 09:49:48 +00:00 
			
		
		
		
	Redo changing of inti
This commit is contained in:
		
							parent
							
								
									f681026a56
								
							
						
					
					
						commit
						d15448f60a
					
				| @ -23,6 +23,49 @@ class FARMReader: | |||||||
|      - fine-tune the model on QA data via train() |      - fine-tune the model on QA data via train() | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         model_name_or_path, | ||||||
|  |         context_window_size=30, | ||||||
|  |         batch_size=50, | ||||||
|  |         use_gpu=True, | ||||||
|  |         no_ans_boost=None, | ||||||
|  |         n_candidates_per_paragraph=1): | ||||||
|  |         """ | ||||||
|  |         :param model_name_or_path: directory of a saved model or the name of a public model: | ||||||
|  |                                    - 'bert-base-cased' | ||||||
|  |                                    - 'deepset/bert-base-cased-squad2' | ||||||
|  |                                    - 'deepset/bert-base-cased-squad2' | ||||||
|  |                                    - 'distilbert-base-uncased-distilled-squad' | ||||||
|  |                                    .... | ||||||
|  |                                    See https://huggingface.co/models for full list of available models. | ||||||
|  |         :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer. | ||||||
|  |         :param batch_size: Number of samples the model receives in one batch for inference | ||||||
|  |                            Memory consumption is much lower in inference mode. Recommendation: increase the batch size to a value so only a single batch is used. | ||||||
|  |         :param use_gpu: Whether to use GPU (if available) | ||||||
|  |         :param no_ans_boost: How much the no_answer logit is boosted/increased. | ||||||
|  |                              Possible values: None (default) = disable returning "no answer" predictions | ||||||
|  |                                               Negative = lower chance of "no answer" being predicted | ||||||
|  |                                               Positive = increase chance of "no answer" | ||||||
|  |         :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`). | ||||||
|  |                                            Note: - This is not the number of "final answers" you will receive | ||||||
|  |                                                    (see `top_k` in FARMReader.predict() or Finder.get_answers() for that) | ||||||
|  |                                                  - FARM includes no_answer in the sorted list of predictions | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         if no_ans_boost is None: | ||||||
|  |             no_ans_boost = 0 | ||||||
|  |             self.return_no_answers = False | ||||||
|  |         else: | ||||||
|  |             self.return_no_answers = True | ||||||
|  |         self.n_candidates_per_paragraph = n_candidates_per_paragraph | ||||||
|  |         self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering") | ||||||
|  |         self.inferencer.model.prediction_heads[0].context_window_size = context_window_size | ||||||
|  |         self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost | ||||||
|  |         self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1 # including possible no_answer | ||||||
|  | 
 | ||||||
|     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None, |     def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None, | ||||||
|               use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5, |               use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5, | ||||||
|               max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None): |               max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None): | ||||||
| @ -104,49 +147,6 @@ class FARMReader: | |||||||
|         self.inferencer.model = trainer.train() |         self.inferencer.model = trainer.train() | ||||||
|         self.save(save_dir) |         self.save(save_dir) | ||||||
| 
 | 
 | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         model_name_or_path, |  | ||||||
|         context_window_size=30, |  | ||||||
|         batch_size=50, |  | ||||||
|         use_gpu=True, |  | ||||||
|         no_ans_boost=None, |  | ||||||
|         n_candidates_per_paragraph=1): |  | ||||||
|         """ |  | ||||||
|         :param model_name_or_path: directory of a saved model or the name of a public model: |  | ||||||
|                                    - 'bert-base-cased' |  | ||||||
|                                    - 'deepset/bert-base-cased-squad2' |  | ||||||
|                                    - 'deepset/bert-base-cased-squad2' |  | ||||||
|                                    - 'distilbert-base-uncased-distilled-squad' |  | ||||||
|                                    .... |  | ||||||
|                                    See https://huggingface.co/models for full list of available models. |  | ||||||
|         :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer. |  | ||||||
|         :param batch_size: Number of samples the model receives in one batch for inference |  | ||||||
|                            Memory consumption is much lower in inference mode. Recommendation: increase the batch size to a value so only a single batch is used. |  | ||||||
|         :param use_gpu: Whether to use GPU (if available) |  | ||||||
|         :param no_ans_boost: How much the no_answer logit is boosted/increased. |  | ||||||
|                              Possible values: None (default) = disable returning "no answer" predictions |  | ||||||
|                                               Negative = lower chance of "no answer" being predicted |  | ||||||
|                                               Positive = increase chance of "no answer" |  | ||||||
|         :param n_candidates_per_paragraph: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`). |  | ||||||
|                                            Note: - This is not the number of "final answers" you will receive |  | ||||||
|                                                    (see `top_k` in FARMReader.predict() or Finder.get_answers() for that) |  | ||||||
|                                                  - FARM includes no_answer in the sorted list of predictions |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         """ |  | ||||||
| 
 |  | ||||||
|         if no_ans_boost is None: |  | ||||||
|             no_ans_boost = 0 |  | ||||||
|             self.return_no_answers = False |  | ||||||
|         else: |  | ||||||
|             self.return_no_answers = True |  | ||||||
|         self.n_candidates_per_paragraph = n_candidates_per_paragraph |  | ||||||
|         self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering") |  | ||||||
|         self.inferencer.model.prediction_heads[0].context_window_size = context_window_size |  | ||||||
|         self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost |  | ||||||
|         self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_paragraph + 1 # including possible no_answer |  | ||||||
| 
 |  | ||||||
|     def save(self, directory): |     def save(self, directory): | ||||||
|         logger.info(f"Saving reader model to {directory}") |         logger.info(f"Saving reader model to {directory}") | ||||||
|         self.inferencer.model.save(directory) |         self.inferencer.model.save(directory) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 timoeller
						timoeller