diff --git a/.gitignore b/.gitignore
index 5452f579a..7a9d68060 100644
--- a/.gitignore
+++ b/.gitignore
@@ -128,5 +128,16 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
+# PyCharm
+.idea
+
 # haystack files
 haystack/database/qa.db
+data
+mlruns
+src
+tutorials/cache
+tutorials/mlruns
+tutorials/model
+model
+
diff --git a/haystack/__init__.py b/haystack/__init__.py
index 03c66a91d..e230e7420 100644
--- a/haystack/__init__.py
+++ b/haystack/__init__.py
@@ -8,6 +8,7 @@ pd.options.display.max_colwidth = 80
 
 logger = logging.getLogger(__name__)
 
 logging.getLogger('farm').setLevel(logging.WARNING)
+logging.getLogger('farm.infer').setLevel(logging.INFO)
 logging.getLogger('transformers').setLevel(logging.WARNING)
 
diff --git a/haystack/api/inference.py b/haystack/api/inference.py
index 31577bca2..7768d41ce 100644
--- a/haystack/api/inference.py
+++ b/haystack/api/inference.py
@@ -38,7 +38,7 @@ if len(model_paths) == 0:
 retriever = TfidfRetriever()
 
 FINDERS = {}
 for idx, model_dir in enumerate(model_paths, start=1):
-    reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
+    reader = FARMReader(model_name_or_path=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
     FINDERS[idx] = Finder(reader, retriever)
     logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")
diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 448fed517..4f346cfec 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -1,36 +1,146 @@
-from farm.infer import Inferencer
 import numpy as np
 from scipy.special import expit
+from pathlib import Path
+import logging
+
+from farm.data_handler.data_silo import DataSilo
+from farm.data_handler.processor import SquadProcessor
+from farm.infer import Inferencer
+from farm.modeling.optimization import initialize_optimizer
+from farm.train import Trainer
+from farm.utils import set_all_seeds, initialize_device_settings
+
+logger = logging.getLogger(__name__)
 
 
 class FARMReader:
     """
-    Implementation of FARM Inferencer for Question Answering.
+    Transformer based model for extractive Question Answering using the FARM framework (https://github.com/deepset-ai/FARM).
+    While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
 
-    The class loads a saved FARM adaptive model from a given directory and runs
-    inference using `inference_from_dicts()` method.
+    With a FARMReader, you can:
+    - directly get predictions via predict()
+    - fine-tune the model on QA data via train()
     """
 
     def __init__(
         self,
-        model_dir,
-        context_size=30,
-        no_answer_shift=-100,
+        model_name_or_path,
+        context_window_size=30,
+        no_ans_threshold=-100,
         batch_size=16,
         use_gpu=True,
-        n_best_per_passage=2
-    ):
+        n_candidates_per_passage=2):
         """
-        Load a saved FARM model in Inference mode.
-
-        :param model_dir: directory path of the saved model
+        :param model_name_or_path: directory of a saved model or the name of a public model:
+                                   - 'bert-base-cased'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'distilbert-base-uncased-distilled-squad'
+                                   ....
+                                   See https://huggingface.co/models for full list of available models.
+        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
+        :param no_ans_threshold: How much greater the no_answer logit needs to be over the pos_answer in order to be chosen.
+                                 The higher the value, the more `uncertain` answers are accepted
+        :param batch_size: Number of samples the model receives in one batch for inference
+        :param use_gpu: Whether to use GPU (if available)
+        :param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
+                                         Note: This is not the number of "final answers" you will receive
+                                         (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
         """
 
-        self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu)
-        self.model.model.prediction_heads[0].context_size = context_size
-        self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift
-        self.model.model.prediction_heads[0].n_best = n_best_per_passage
+        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
+        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
+        self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_threshold
+        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
+
+    def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
+              use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
+              max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
+        """
+        Fine-tune a model on a QA dataset. Options:
+        - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
+        - Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
+
+        :param data_dir: Path to directory containing your training data in SQuAD style
+        :param train_filename: filename of training data
+        :param dev_filename: filename of dev / eval data
+        :param test_file_name: filename of test data
+        :param dev_split: Instead of specifying a dev_filename you can also specify a ratio (e.g. 0.1) here
+                          that gets split off from training data for eval.
+        :param use_gpu: Whether to use GPU (if available)
+        :param batch_size: Number of samples the model receives in one batch for training
+        :param n_epochs: number of iterations on the whole training data set
+        :param learning_rate: learning rate of the optimizer
+        :param max_seq_len: maximum text length (in tokens). Everything longer gets cut down.
+        :param warmup_proportion: Proportion of training steps until maximum learning rate is reached.
+                                  Until that point LR is increasing linearly. After that it's decreasing again linearly.
+                                  Options for different schedules are available in FARM.
+        :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
+        :param save_dir: Path to store the final model
+        :return: None
+        """
+
+        if dev_filename:
+            dev_split = None
+
+        set_all_seeds(seed=42)
+        device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
+
+        if not save_dir:
+            save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
+        save_dir = Path(save_dir)
+
+        # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
+        label_list = ["start_token", "end_token"]
+        metric = "squad"
+        processor = SquadProcessor(
+            tokenizer=self.inferencer.processor.tokenizer,
+            max_seq_len=max_seq_len,
+            label_list=label_list,
+            metric=metric,
+            train_filename=train_filename,
+            dev_filename=dev_filename,
+            dev_split=dev_split,
+            test_filename=test_file_name,
+            data_dir=Path(data_dir),
+        )
+
+        # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
+        #    and calculates a few descriptive statistics of our datasets
+        data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)
+
+        # 3. Create an optimizer and pass the already initialized model
+        model, optimizer, lr_schedule = initialize_optimizer(
+            model=self.inferencer.model,
+            learning_rate=learning_rate,
+            schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
+            n_batches=len(data_silo.loaders["train"]),
+            n_epochs=n_epochs,
+            device=device
+        )
+        # 4. Feed everything to the Trainer, which takes care of growing our model and evaluates it from time to time
+        trainer = Trainer(
+            model=model,
+            optimizer=optimizer,
+            data_silo=data_silo,
+            epochs=n_epochs,
+            n_gpu=n_gpu,
+            lr_schedule=lr_schedule,
+            evaluate_every=evaluate_every,
+            device=device,
+        )
+        # 5. Let it grow!
+        self.inferencer.model = trainer.train()
+        self.save(save_dir)
+
+    def save(self, directory):
+        logger.info(f"Saving reader model to {directory}")
+        self.inferencer.model.save(directory)
+        self.inferencer.processor.save(directory)
+
     def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
         """
         Use loaded QA model to find answers for a question in the supplied paragraphs.
@@ -74,7 +184,7 @@ class FARMReader:
             input_dicts.append(cur)
 
         # get answers from QA model (Top 5 per input paragraph)
-        predictions = self.model.inference_from_dicts(
+        predictions = self.inferencer.inference_from_dicts(
             dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
         )
diff --git a/haystack/reader/transformers.py b/haystack/reader/transformers.py
index 0c8520fb7..440f1a6d6 100644
--- a/haystack/reader/transformers.py
+++ b/haystack/reader/transformers.py
@@ -3,8 +3,12 @@ from transformers import pipeline
 
 class TransformersReader:
     """
-    A reader using the QA Pipeline class from huggingface's Transformers.
-    Easily load any of the pretrained community QA models from here: https://huggingface.co/models
+    Transformer based model for extractive Question Answering using huggingface's transformers framework
+    (https://github.com/huggingface/transformers).
+    While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
+
+    With the reader, you can:
+    - directly get predictions via predict()
     """
 
     def __init__(
diff --git a/tutorials/Tutorial1_Basic_QA_Pipeline.py b/tutorials/Tutorial1_Basic_QA_Pipeline.py
index 6648535a5..31d3b428e 100755
--- a/tutorials/Tutorial1_Basic_QA_Pipeline.py
+++ b/tutorials/Tutorial1_Basic_QA_Pipeline.py
@@ -33,7 +33,7 @@ retriever = TfidfRetriever()
 # A reader scans the text chunks in detail and extracts the k best answers
 # Reader use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0
 fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
-reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)
+reader = FARMReader(model_name_or_path="model/bert-english-qa-large", use_gpu=False)
 
 # OR: use alternatively a reader from huggingface's Transformers package
 # reader = TransformersReader(use_gpu=-1)
diff --git a/tutorials/Tutorial2_Finetune_a_model_on_your_data.py b/tutorials/Tutorial2_Finetune_a_model_on_your_data.py
new file mode 100755
index 000000000..ff3b1550d
--- /dev/null
+++ b/tutorials/Tutorial2_Finetune_a_model_on_your_data.py
@@ -0,0 +1,47 @@
+from haystack.reader.farm import FARMReader
+from haystack.reader.transformers import TransformersReader
+from haystack.retriever.tfidf import TfidfRetriever
+from haystack import Finder
+from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
+from haystack.indexing.cleaning import clean_wiki_text
+from haystack.utils import print_answers
+
+#### TRAINING #############
+# Let's take a reader as a base model
+reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)
+
+# and fine-tune it on your own custom dataset (should be in SQuAD-like format)
+reader.train(data_dir="../data/squad_small", train_filename="train.json", use_gpu=False, n_epochs=1)
+
+
+#### Use it (same as in Tutorial 1) #############
+
+# Okay, we have a fine-tuned model now. Let's test it on some docs:
+## Let's get some docs for testing (see Tutorial 1 for more explanations)
+from haystack.database import db
+db.create_all()
+
+# Download docs
+doc_dir = "data/article_txt_got"
+s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
+fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
+
+# Write docs to our DB.
+write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)
+
+# Initialize Finder Pipeline
+retriever = TfidfRetriever()
+finder = Finder(reader, retriever)
+
+## Voilà! Ask a question!
+# You can configure how many candidates the reader and retriever shall return
+# The higher top_k_retriever, the better (but also the slower) your answers.
+prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
+
+#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
+#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
+
+print_answers(prediction, details="minimal")
+
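
Review note (not part of the patch): a minimal sketch of how the renamed FARMReader constructor arguments map to their old names and how the new train() entry point is called. Only the argument names and methods are taken from this diff; the data paths, filenames, and save directory below are placeholders.

from haystack.reader.farm import FARMReader

# Old constructor (before this patch):
#   FARMReader(model_dir="model/bert-english-qa-large", context_size=30,
#              no_answer_shift=-100, n_best_per_passage=2)
# New constructor introduced here:
reader = FARMReader(
    model_name_or_path="distilbert-base-uncased-distilled-squad",  # local dir or public model name
    context_window_size=30,       # formerly context_size
    no_ans_threshold=-100,        # formerly no_answer_shift
    n_candidates_per_passage=2,   # formerly n_best_per_passage
    use_gpu=False,
)

# Fine-tune on SQuAD-style data; train() saves the resulting model itself.
reader.train(
    data_dir="data/my_squad_data",                 # placeholder directory
    train_filename="train.json",
    n_epochs=1,
    use_gpu=False,
    save_dir="saved_models/my_finetuned_reader",   # placeholder path
)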