Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-10-31 17:59:27 +00:00
Add method to train a reader on custom data (#5)

* initial version of training a reader WIP
* update for latest changes in FARM inferencer. Update tutorial. Add basic docs
This commit is contained in:
parent 8a48cd7dd6
commit 1718ea55b8
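In short: this commit adds a train() method to FARMReader, so a public or locally saved QA model can be fine-tuned on custom SQuAD-style data and saved for reuse. A minimal sketch of the resulting workflow (the data path and save directory are illustrative; the calls themselves are the ones introduced in the diff below):

    from haystack.reader.farm import FARMReader

    # Start from a public QA model (or a plain language model such as "bert-base-cased")
    reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)

    # Fine-tune on your own SQuAD-style data; train() saves the final model to save_dir
    reader.train(data_dir="data/my_squad_data", train_filename="train.json",
                 use_gpu=False, n_epochs=1, save_dir="my_model")

    # The saved directory can then be loaded like any other model
    reader = FARMReader(model_name_or_path="my_model", use_gpu=False)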
.gitignore (vendored) | 11
@@ -128,5 +128,16 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
+# PyCharm
+.idea
+
+# haystack files
+haystack/database/qa.db
+data
+mlruns
+src
+tutorials/cache
+tutorials/mlruns
+tutorials/model
+model
@@ -8,6 +8,7 @@ pd.options.display.max_colwidth = 80
 logger = logging.getLogger(__name__)
 
 logging.getLogger('farm').setLevel(logging.WARNING)
+logging.getLogger('farm.infer').setLevel(logging.INFO)
 logging.getLogger('transformers').setLevel(logging.WARNING)
@@ -38,7 +38,7 @@ if len(model_paths) == 0:
 retriever = TfidfRetriever()
 FINDERS = {}
 for idx, model_dir in enumerate(model_paths, start=1):
-    reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
+    reader = FARMReader(model_name_or_path=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
     FINDERS[idx] = Finder(reader, retriever)
     logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")
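This hunk keeps one Finder per discovered model directory, keyed by a 1-based ID. A hedged sketch of how the registry is meant to be used (the lookup and question are hypothetical; get_answers() mirrors the tutorial further down):

    # Hypothetical lookup against the FINDERS registry built above
    finder = FINDERS[1]  # 1-based IDs, one Finder per model directory
    prediction = finder.get_answers(question="Who is the father of Arya Stark?",
                                    top_k_retriever=10, top_k_reader=5)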
@@ -1,36 +1,146 @@
-from farm.infer import Inferencer
 import numpy as np
 from scipy.special import expit
+from pathlib import Path
+import logging
+
+from farm.data_handler.data_silo import DataSilo
+from farm.data_handler.processor import SquadProcessor
+from farm.infer import Inferencer
+from farm.modeling.optimization import initialize_optimizer
+from farm.train import Trainer
+from farm.utils import set_all_seeds, initialize_device_settings
+
+logger = logging.getLogger(__name__)
 
 
 class FARMReader:
     """
-    Implementation of FARM Inferencer for Question Answering.
+    Transformer based model for extractive Question Answering using the FARM framework (https://github.com/deepset-ai/FARM).
+    While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
 
-    The class loads a saved FARM adaptive model from a given directory and runs
-    inference using `inference_from_dicts()` method.
+    With a FARMReader, you can:
+     - directly get predictions via predict()
+     - fine-tune the model on QA data via train()
     """
 
     def __init__(
         self,
-        model_dir,
-        context_size=30,
-        no_answer_shift=-100,
+        model_name_or_path,
+        context_window_size=30,
+        no_ans_threshold=-100,
         batch_size=16,
         use_gpu=True,
-        n_best_per_passage=2
-    ):
+        n_candidates_per_passage=2):
         """
         Load a saved FARM model in Inference mode.
 
-        :param model_dir: directory path of the saved model
+        :param model_name_or_path: directory of a saved model or the name of a public model:
+                                   - 'bert-base-cased'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'distilbert-base-uncased-distilled-squad'
+                                   ....
+                                   See https://huggingface.co/models for full list of available models.
+        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
+        :param no_ans_threshold: How much greater the no_answer logit needs to be over the pos_answer in order to be chosen.
+                                 The higher the value, the more `uncertain` answers are accepted
+        :param batch_size: Number of samples the model receives in one batch for inference
+        :param use_gpu: Whether to use GPU (if available)
+        :param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
+                                         Note: This is not the number of "final answers" you will receive
+                                         (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
         """
-        self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu)
-        self.model.model.prediction_heads[0].context_size = context_size
-        self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift
-        self.model.model.prediction_heads[0].n_best = n_best_per_passage
-
-
+        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
+        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
+        self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_threshold
+        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
 
+    def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
+              use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
+              max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
+        """
+        Fine-tune a model on a QA dataset. Options:
+        - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
+        - Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
+
+        :param data_dir: Path to directory containing your training data in SQuAD style
+        :param train_filename: filename of training data
+        :param dev_filename: filename of dev / eval data
+        :param test_file_name: filename of test data
+        :param dev_split: Instead of specifying a dev_filename you can also specify a ratio (e.g. 0.1) here
+                          that gets split off from the training data for eval.
+        :param use_gpu: Whether to use GPU (if available)
+        :param batch_size: Number of samples the model receives in one batch for training
+        :param n_epochs: number of iterations on the whole training data set
+        :param learning_rate: learning rate of the optimizer
+        :param max_seq_len: maximum text length (in tokens). Everything longer gets cut down.
+        :param warmup_proportion: Proportion of training steps until maximum learning rate is reached.
+                                  Until that point the LR is increasing linearly. After that it's decreasing again linearly.
+                                  Options for different schedules are available in FARM.
+        :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
+        :param save_dir: Path to store the final model
+        :return: None
+        """
+
+        if dev_filename:
+            dev_split = None
+
+        set_all_seeds(seed=42)
+        device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
+
+        if not save_dir:
+            save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
+        save_dir = Path(save_dir)
+
+        # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
+        label_list = ["start_token", "end_token"]
+        metric = "squad"
+        processor = SquadProcessor(
+            tokenizer=self.inferencer.processor.tokenizer,
+            max_seq_len=max_seq_len,
+            label_list=label_list,
+            metric=metric,
+            train_filename=train_filename,
+            dev_filename=dev_filename,
+            dev_split=dev_split,
+            test_filename=test_file_name,
+            data_dir=Path(data_dir),
+        )
+
+        # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
+        # and calculates a few descriptive statistics of our datasets
+        data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)
+
+        # 3. Create an optimizer and pass the already initialized model
+        model, optimizer, lr_schedule = initialize_optimizer(
+            model=self.inferencer.model,
+            learning_rate=learning_rate,
+            schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
+            n_batches=len(data_silo.loaders["train"]),
+            n_epochs=n_epochs,
+            device=device
+        )
+        # 4. Feed everything to the Trainer, which takes care of growing our model and evaluates it from time to time
+        trainer = Trainer(
+            model=model,
+            optimizer=optimizer,
+            data_silo=data_silo,
+            epochs=n_epochs,
+            n_gpu=n_gpu,
+            lr_schedule=lr_schedule,
+            evaluate_every=evaluate_every,
+            device=device,
+        )
+        # 5. Let it grow!
+        self.inferencer.model = trainer.train()
+        self.save(save_dir)
+
+    def save(self, directory):
+        logger.info(f"Saving reader model to {directory}")
+        self.inferencer.model.save(directory)
+        self.inferencer.processor.save(directory)
 
     def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
         """
         Use loaded QA model to find answers for a question in the supplied paragraphs.
@@ -74,7 +184,7 @@ class FARMReader:
             input_dicts.append(cur)
 
         # get answers from QA model (Top 5 per input paragraph)
-        predictions = self.model.inference_from_dicts(
+        predictions = self.inferencer.inference_from_dicts(
             dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
         )
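Two of the new docstrings above compress numeric behavior into one line each. An illustrative sketch of both rules, assuming only what the docstrings state (this is not FARM's actual implementation):

    def prefers_no_answer(no_answer_logit, best_answer_logit, no_ans_threshold=-100):
        # "no answer" wins only if its logit beats the best positive answer by more
        # than the threshold; the very negative default (-100) means a real answer is
        # almost always returned, while higher values accept more "uncertain" answers.
        return (no_answer_logit - best_answer_logit) > no_ans_threshold

    def linear_warmup_lr(step, total_steps, max_lr=1e-5, warmup_proportion=0.2):
        # The LR rises linearly to max_lr over the first warmup_proportion of all
        # training steps, then decays linearly back towards 0 over the remaining steps.
        warmup_steps = int(total_steps * warmup_proportion)
        if step < warmup_steps:
            return max_lr * step / warmup_steps
        return max_lr * (total_steps - step) / (total_steps - warmup_steps)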
@@ -3,8 +3,12 @@ from transformers import pipeline
 
 class TransformersReader:
     """
-    A reader using the QA Pipeline class from huggingface's Transformers.
-    Easily load any of the pretrained community QA models from here: https://huggingface.co/models
+    Transformer based model for extractive Question Answering using huggingface's transformers framework
+    (https://github.com/huggingface/transformers).
+    While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
+
+    With the reader, you can:
+     - directly get predictions via predict()
     """
 
     def __init__(
@@ -33,7 +33,7 @@ retriever = TfidfRetriever()
 # A reader scans the text chunks in detail and extracts the k best answers
 # Readers use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0
 fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
-reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)
+reader = FARMReader(model_name_or_path="model/bert-english-qa-large", use_gpu=False)
 
 # OR: alternatively use a reader from huggingface's Transformers package
 # reader = TransformersReader(use_gpu=-1)
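Note that the model_dir to model_name_or_path rename is a breaking change for existing callers, as this hunk shows; any code constructing a FARMReader needs the same one-keyword update:

    # Before this commit (now raises a TypeError):
    # reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)

    # After this commit:
    reader = FARMReader(model_name_or_path="model/bert-english-qa-large", use_gpu=False)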
tutorials/Tutorial2_Finetune_a_model_on_your_data.py (executable file) | 47
@@ -0,0 +1,47 @@
+from haystack.reader.farm import FARMReader
+from haystack.reader.transformers import TransformersReader
+from haystack.retriever.tfidf import TfidfRetriever
+from haystack import Finder
+from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
+from haystack.indexing.cleaning import clean_wiki_text
+from haystack.utils import print_answers
+
+#### TRAINING #############
+# Let's take a reader as a base model
+reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)
+
+# and fine-tune it on your own custom dataset (should be in SQuAD-like format)
+reader.train(data_dir="../data/squad_small", train_filename="train.json", use_gpu=False, n_epochs=1)
+
+
+#### Use it (same as in Tutorial 1) #############
+
+# Okay, we have a fine-tuned model now. Let's test it on some docs:
+## Let's get some docs for testing (see Tutorial 1 for more explanations)
+from haystack.database import db
+db.create_all()
+
+# Download docs
+doc_dir = "data/article_txt_got"
+s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
+fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
+
+# Write docs to our DB.
+write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)
+
+# Initialize Finder Pipeline
+retriever = TfidfRetriever()
+finder = Finder(reader, retriever)
+
+## Voilà! Ask a question!
+# You can configure how many candidates the reader and retriever shall return
+# The higher top_k_retriever, the better (but also the slower) your answers.
+prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
+
+#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
+#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
+
+print_answers(prediction, details="minimal")
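The tutorial's train.json only has to be "SQuAD like", which the diff never spells out. For orientation, a minimal sketch of that format, written here as a Python dict; the field names follow the public SQuAD v2 schema, and the content itself is made up:

    # Minimal SQuAD-style training data (illustrative sample, SQuAD v2 field names)
    train_data = {
        "version": "v2.0",
        "data": [{
            "title": "Arya Stark",
            "paragraphs": [{
                "context": "Arya Stark is a daughter of Lord Eddard Stark and Lady Catelyn Stark.",
                "qas": [{
                    "id": "1",
                    "question": "Who is the father of Arya Stark?",
                    "answers": [{"text": "Eddard Stark", "answer_start": 33}],
                    "is_impossible": False,
                }],
            }],
        }],
    }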
Malte Pietsch