Mirror of https://github.com/deepset-ai/haystack.git
Add method to train a reader on custom data (#5)

* Initial version of training a reader (WIP)
* Update for latest changes in the FARM Inferencer; update tutorial; add basic docs
parent 8a48cd7dd6
commit 1718ea55b8
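In short, the commit adds a train() method to FARMReader, so a reader can be loaded and then fine-tuned on SQuAD-style data. A condensed sketch of the workflow, using the same model name and data paths as the new Tutorial2 further down:

from haystack.reader.farm import FARMReader

# Load a public QA model as the starting point (same model as in Tutorial2)
reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)

# Fine-tune it on a custom dataset in SQuAD-like format
reader.train(data_dir="../data/squad_small", train_filename="train.json", use_gpu=False, n_epochs=1)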
					
				
							
								
								
									
.gitignore (vendored) | 11
@@ -128,5 +128,16 @@ dmypy.json
# Pyre type checker
.pyre/

# PyCharm
.idea

# haystack files
haystack/database/qa.db
data
mlruns
src
tutorials/cache
tutorials/mlruns
tutorials/model
model

@@ -8,6 +8,7 @@ pd.options.display.max_colwidth = 80
logger = logging.getLogger(__name__)

logging.getLogger('farm').setLevel(logging.WARNING)
logging.getLogger('farm.infer').setLevel(logging.INFO)
logging.getLogger('transformers').setLevel(logging.WARNING)

@@ -38,7 +38,7 @@ if len(model_paths) == 0:
retriever = TfidfRetriever()
FINDERS = {}
for idx, model_dir in enumerate(model_paths, start=1):
    reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
    reader = FARMReader(model_name_or_path=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
    FINDERS[idx] = Finder(reader, retriever)
    logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")

@@ -1,36 +1,146 @@
from farm.infer import Inferencer
import numpy as np
from scipy.special import expit
from pathlib import Path
import logging

from farm.data_handler.data_silo import DataSilo
from farm.data_handler.processor import SquadProcessor
from farm.infer import Inferencer
from farm.modeling.optimization import initialize_optimizer
from farm.train import Trainer
from farm.utils import set_all_seeds, initialize_device_settings

logger = logging.getLogger(__name__)


class FARMReader:
    """
    Implementation of FARM Inferencer for Question Answering.
    Transformer based model for extractive Question Answering using the FARM framework (https://github.com/deepset-ai/FARM).
    While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.

    The class loads a saved FARM adaptive model from a given directory and runs
    inference using `inference_from_dicts()` method.
    With a FARMReader, you can:
     - directly get predictions via predict()
     - fine-tune the model on QA data via train()
    """

    def __init__(
        self,
        model_dir,
        context_size=30,
        no_answer_shift=-100,
        model_name_or_path,
        context_window_size=30,
        no_ans_threshold=-100,
        batch_size=16,
        use_gpu=True,
        n_best_per_passage=2
    ):
        n_candidates_per_passage=2):
        """
 | 
			
		||||
        Load a saved FARM model in Inference mode.
 | 
			
		||||
 | 
			
		||||
        :param model_dir: directory path of the saved model
 | 
			
		||||
        :param model_name_or_path: directory of a saved model or the name of a public model:
 | 
			
		||||
                                   - 'bert-base-cased'
 | 
			
		||||
                                   - 'deepset/bert-base-cased-squad2'
 | 
			
		||||
                                   - 'deepset/bert-base-cased-squad2'
 | 
			
		||||
                                   - 'distilbert-base-uncased-distilled-squad'
 | 
			
		||||
                                   ....
 | 
			
		||||
                                   See https://huggingface.co/models for full list of available models.
 | 
			
		||||
        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
 | 
			
		||||
        :param no_ans_threshold: How much greater the no_answer logit needs to be over the pos_answer in order to be chosen.
 | 
			
		||||
                                 The higher the value, the more `uncertain` answers are accepted
 | 
			
		||||
        :param batch_size: Number of samples the model receives in one batch for inference
 | 
			
		||||
        :param use_gpu: Whether to use GPU (if available)
 | 
			
		||||
        :param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
 | 
			
		||||
                                         Note: This is not the number of "final answers" you will receive
 | 
			
		||||
                                         (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
 | 
			
		||||
        """
 | 
			
		||||
        self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu)
 | 
			
		||||
        self.model.model.prediction_heads[0].context_size = context_size
 | 
			
		||||
        self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift
 | 
			
		||||
        self.model.model.prediction_heads[0].n_best = n_best_per_passage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
 | 
			
		||||
        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
 | 
			
		||||
        self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_threshold
 | 
			
		||||
        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
 | 
			
		||||
 | 
			
		||||
    def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
              use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
              max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
        """
        Fine-tune a model on a QA dataset. Options:
        - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
        - Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)

        :param data_dir: Path to directory containing your training data in SQuAD style
        :param train_filename: filename of training data
        :param dev_filename: filename of dev / eval data
        :param test_file_name: filename of test data
        :param dev_split: Instead of specifying a dev_filename you can also specify a ratio (e.g. 0.1) here
                          that get's split off from training data for eval.
        :param use_gpu: Whether to use GPU (if available)
        :param batch_size: Number of samples the model receives in one batch for training
        :param n_epochs: number of iterations on the whole training data set
        :param learning_rate: learning rate of the optimizer
        :param max_seq_len: maximum text length (in tokens). Everything longer gets cut down.
        :param warmup_proportion: Proportion of training steps until maximum learning rate is reached.
                                  Until that point LR is increasing linearly. After that it's decreasing again linearly.
                                  Options for different schedules are available in FARM.
        :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
        :param save_dir: Path to store the final model
        :return: None
        """

        if dev_filename:
            dev_split = None

        set_all_seeds(seed=42)
        device, n_gpu = initialize_device_settings(use_cuda=use_gpu)

        if not save_dir:
            save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
        save_dir = Path(save_dir)

        # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
        label_list = ["start_token", "end_token"]
        metric = "squad"
        processor = SquadProcessor(
            tokenizer=self.inferencer.processor.tokenizer,
            max_seq_len=max_seq_len,
            label_list=label_list,
            metric=metric,
            train_filename=train_filename,
            dev_filename=dev_filename,
            dev_split=dev_split,
            test_filename=test_file_name,
            data_dir=Path(data_dir),
        )

        # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
        # and calculates a few descriptive statistics of our datasets
        data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)

        # 3. Create an optimizer and pass the already initialized model
        model, optimizer, lr_schedule = initialize_optimizer(
            model=self.inferencer.model,
            learning_rate=learning_rate,
            schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
            n_batches=len(data_silo.loaders["train"]),
            n_epochs=n_epochs,
            device=device
        )
        # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            data_silo=data_silo,
            epochs=n_epochs,
            n_gpu=n_gpu,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            device=device,
        )
        # 5. Let it grow!
        self.inferencer.model = trainer.train()
        self.save(save_dir)

    def save(self, directory):
        logger.info(f"Saving reader model to {directory}")
        self.inferencer.model.save(directory)
        self.inferencer.processor.save(directory)

    def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
        """
        Use loaded QA model to find answers for a question in the supplied paragraphs.

@@ -74,7 +184,7 @@ class FARMReader:
            input_dicts.append(cur)

        # get answers from QA model (Top 5 per input paragraph)
        predictions = self.model.inference_from_dicts(
        predictions = self.inferencer.inference_from_dicts(
            dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
        )

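Since train() writes the fine-tuned model to save_dir when it finishes, a typical call would also pass a dev set and an explicit output directory. A minimal sketch based on the signature above; the data paths and file names are placeholders, not part of this commit:

reader.train(
    data_dir="data/my_squad_data",      # placeholder: directory with SQuAD-style JSON files
    train_filename="train.json",
    dev_filename="dev.json",            # if given, dev_split is ignored (set to None internally)
    n_epochs=2,
    evaluate_every=300,
    save_dir="my_model",                # placeholder: model and processor are written here via save()
)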
@@ -3,8 +3,12 @@ from transformers import pipeline

class TransformersReader:
    """
    A reader using the QA Pipeline class from huggingface's Transformers.
    Easily load any of the pretrained community QA models from here: https://huggingface.co/models
    Transformer based model for extractive Question Answering using the huggingface's transformers framework
    (https://github.com/huggingface/transformers).
    While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.

    With the reader, you can:
     - directly get predictions via predict()
    """

    def __init__(

@@ -33,7 +33,7 @@ retriever = TfidfRetriever()
# A reader scans the text chunks in detail and extracts the k best answers
# Reader use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0
fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)
reader = FARMReader(model_name_or_path="model/bert-english-qa-large", use_gpu=False)

# OR: use alternatively a reader from huggingface's Transformers package
# reader = TransformersReader(use_gpu=-1)

tutorials/Tutorial2_Finetune_a_model_on_your_data.py (executable file) | 47

@@ -0,0 +1,47 @@
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.retriever.tfidf import TfidfRetriever
from haystack import Finder
from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
from haystack.indexing.cleaning import clean_wiki_text
from haystack.utils import print_answers

#### TRAINING #############
# Let's take a reader as a base model
reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)

# and fine-tune it on your own custom dataset (should be in SQuAD like format)
reader.train(data_dir="../data/squad_small", train_filename="train.json", use_gpu=False, n_epochs=1)


#### Use it (same as in Tutorial 1) #############

# Okay, we have a fine-tuned model now. Let's test it on some docs:
## Let's get some docs for testing (see Tutorial 1 for more explanations)
from haystack.database import db
db.create_all()

# Download docs
doc_dir = "data/article_txt_got"
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)

# Write docs to our DB.
write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)

# Initialize Finder Pipeline
retriever = TfidfRetriever()
finder = Finder(reader, retriever)

## Voilá! Ask a question!
# You can configure how many candidates the reader and retriever shall return
# The higher top_k_retriever, the better (but also the slower) your answers.
prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)

#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)

print_answers(prediction, details="minimal")
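Because model_name_or_path also accepts the directory of a saved model, the reader fine-tuned in this tutorial can presumably be reloaded later from the directory it was saved to and plugged into a Finder as above. A hedged sketch; "my_model" is a placeholder directory, not a path used in this commit:

# Reload the fine-tuned reader from the directory it was saved to (placeholder path)
reader = FARMReader(model_name_or_path="my_model", use_gpu=False)
finder = Finder(reader, retriever)
prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
print_answers(prediction, details="minimal")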