mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 03:46:30 +00:00
Add method to train a reader on custom data (#5)
* initial version of training a reader WIP * update for latest changes in FARM inferencer. Update tutorial. Add basic docs
This commit is contained in:
parent
8a48cd7dd6
commit
1718ea55b8
11
.gitignore
vendored
11
.gitignore
vendored
@ -128,5 +128,16 @@ dmypy.json
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# PyCharm
|
||||
.idea
|
||||
|
||||
# haystack files
|
||||
haystack/database/qa.db
|
||||
data
|
||||
mlruns
|
||||
src
|
||||
tutorials/cache
|
||||
tutorials/mlruns
|
||||
tutorials/model
|
||||
model
|
||||
|
||||
|
@ -8,6 +8,7 @@ pd.options.display.max_colwidth = 80
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logging.getLogger('farm').setLevel(logging.WARNING)
|
||||
logging.getLogger('farm.infer').setLevel(logging.INFO)
|
||||
logging.getLogger('transformers').setLevel(logging.WARNING)
|
||||
|
||||
|
||||
|
@ -38,7 +38,7 @@ if len(model_paths) == 0:
|
||||
retriever = TfidfRetriever()
|
||||
FINDERS = {}
|
||||
for idx, model_dir in enumerate(model_paths, start=1):
|
||||
reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
|
||||
reader = FARMReader(model_name_or_path=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
|
||||
FINDERS[idx] = Finder(reader, retriever)
|
||||
logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")
|
||||
|
||||
|
@ -1,36 +1,146 @@
|
||||
from farm.infer import Inferencer
|
||||
import numpy as np
|
||||
from scipy.special import expit
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from farm.data_handler.data_silo import DataSilo
|
||||
from farm.data_handler.processor import SquadProcessor
|
||||
from farm.infer import Inferencer
|
||||
from farm.modeling.optimization import initialize_optimizer
|
||||
from farm.train import Trainer
|
||||
from farm.utils import set_all_seeds, initialize_device_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FARMReader:
|
||||
"""
|
||||
Implementation of FARM Inferencer for Question Answering.
|
||||
Transformer based model for extractive Question Answering using the FARM framework (https://github.com/deepset-ai/FARM).
|
||||
While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
|
||||
|
||||
The class loads a saved FARM adaptive model from a given directory and runs
|
||||
inference using `inference_from_dicts()` method.
|
||||
With a FARMReader, you can:
|
||||
- directly get predictions via predict()
|
||||
- fine-tune the model on QA data via train()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_dir,
|
||||
context_size=30,
|
||||
no_answer_shift=-100,
|
||||
model_name_or_path,
|
||||
context_window_size=30,
|
||||
no_ans_threshold=-100,
|
||||
batch_size=16,
|
||||
use_gpu=True,
|
||||
n_best_per_passage=2
|
||||
):
|
||||
n_candidates_per_passage=2):
|
||||
"""
|
||||
Load a saved FARM model in Inference mode.
|
||||
|
||||
:param model_dir: directory path of the saved model
|
||||
:param model_name_or_path: directory of a saved model or the name of a public model:
|
||||
- 'bert-base-cased'
|
||||
- 'deepset/bert-base-cased-squad2'
|
||||
- 'deepset/bert-base-cased-squad2'
|
||||
- 'distilbert-base-uncased-distilled-squad'
|
||||
....
|
||||
See https://huggingface.co/models for full list of available models.
|
||||
:param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
|
||||
:param no_ans_threshold: How much greater the no_answer logit needs to be over the pos_answer in order to be chosen.
|
||||
The higher the value, the more `uncertain` answers are accepted
|
||||
:param batch_size: Number of samples the model receives in one batch for inference
|
||||
:param use_gpu: Whether to use GPU (if available)
|
||||
:param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
|
||||
Note: This is not the number of "final answers" you will receive
|
||||
(see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
|
||||
"""
|
||||
self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu)
|
||||
self.model.model.prediction_heads[0].context_size = context_size
|
||||
self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift
|
||||
self.model.model.prediction_heads[0].n_best = n_best_per_passage
|
||||
|
||||
|
||||
self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
|
||||
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
|
||||
self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_threshold
|
||||
self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
|
||||
|
||||
def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
          use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
          max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
    """
    Fine-tune the reader's model on a QA dataset and save the result.

    Typical use cases:
    - Start from a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
    - Start from a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain
      (e.g. using your labels collected via the haystack annotation tool)

    The random seed is fixed to 42 before training. After training, the model held by the
    reader's inferencer is replaced with the trained one and written to `save_dir` via `save()`.

    :param data_dir: Path to directory containing your training data in SQuAD style
    :param train_filename: filename of training data
    :param dev_filename: filename of dev / eval data. If given, `dev_split` is ignored.
    :param test_file_name: filename of test data
    :param dev_split: Instead of specifying a dev_filename you can also specify a ratio (e.g. 0.1) here
                      that gets split off from training data for eval.
    :param use_gpu: Whether to use GPU (if available)
    :param batch_size: Number of samples the model receives in one batch for training
    :param n_epochs: number of iterations on the whole training data set
    :param learning_rate: learning rate of the optimizer
    :param max_seq_len: maximum text length (in tokens). Everything longer gets cut down.
    :param warmup_proportion: Proportion of training steps until maximum learning rate is reached.
                              Until that point LR is increasing linearly. After that it's decreasing again linearly.
                              Options for different schedules are available in FARM.
    :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
    :param save_dir: Path to store the final model. When omitted, a directory named after the
                     language model below `../../saved_models/` is used.
    :return: None
    """
    # An explicitly supplied dev file takes precedence over splitting off training data.
    if dev_filename:
        dev_split = None

    set_all_seeds(seed=42)
    device, n_gpu = initialize_device_settings(use_cuda=use_gpu)

    # Fall back to a directory named after the underlying language model.
    target_dir = Path(save_dir or f"../../saved_models/{self.inferencer.model.language_model.name}")

    # 1. Processor: converts the raw SQuAD-style files into a pytorch Dataset,
    #    re-using the tokenizer of the already loaded model.
    qa_processor = SquadProcessor(
        tokenizer=self.inferencer.processor.tokenizer,
        max_seq_len=max_seq_len,
        label_list=["start_token", "end_token"],
        metric="squad",
        train_filename=train_filename,
        dev_filename=dev_filename,
        dev_split=dev_split,
        test_filename=test_file_name,
        data_dir=Path(data_dir),
    )

    # 2. DataSilo: loads the train/dev/test datasets and provides DataLoaders for them.
    silo = DataSilo(processor=qa_processor, batch_size=batch_size, distributed=False)

    # 3. Optimizer and LR schedule for the already initialized model.
    warmup_schedule = {"name": "LinearWarmup", "warmup_proportion": warmup_proportion}
    prepared_model, optim, scheduler = initialize_optimizer(
        model=self.inferencer.model,
        learning_rate=learning_rate,
        schedule_opts=warmup_schedule,
        n_batches=len(silo.loaders["train"]),
        n_epochs=n_epochs,
        device=device,
    )

    # 4. Trainer: takes care of training our model and evaluates it from time to time.
    trainer = Trainer(
        model=prepared_model,
        optimizer=optim,
        data_silo=silo,
        epochs=n_epochs,
        n_gpu=n_gpu,
        lr_schedule=scheduler,
        evaluate_every=evaluate_every,
        device=device,
    )

    # 5. Train, keep the resulting model for inference and persist it.
    self.inferencer.model = trainer.train()
    self.save(target_dir)
|
||||
|
||||
def save(self, directory):
    """
    Write the reader's model and its processor to `directory`.

    :param directory: target path handed to the `save()` methods of the inferencer's model and processor
    :return: None
    """
    logger.info(f"Saving reader model to {directory}")
    inferencer = self.inferencer
    # Model first, then the processor, into the same directory.
    inferencer.model.save(directory)
    inferencer.processor.save(directory)
|
||||
|
||||
def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
|
||||
"""
|
||||
Use loaded QA model to find answers for a question in the supplied paragraphs.
|
||||
@ -74,7 +184,7 @@ class FARMReader:
|
||||
input_dicts.append(cur)
|
||||
|
||||
# get answers from QA model (Top 5 per input paragraph)
|
||||
predictions = self.model.inference_from_dicts(
|
||||
predictions = self.inferencer.inference_from_dicts(
|
||||
dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
|
||||
)
|
||||
|
||||
|
@ -3,8 +3,12 @@ from transformers import pipeline
|
||||
|
||||
class TransformersReader:
|
||||
"""
|
||||
A reader using the QA Pipeline class from huggingface's Transformers.
|
||||
Easily load any of the pretrained community QA models from here: https://huggingface.co/models
|
||||
Transformer-based model for extractive Question Answering using huggingface's transformers framework
|
||||
(https://github.com/huggingface/transformers).
|
||||
While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
|
||||
|
||||
With the reader, you can:
|
||||
- directly get predictions via predict()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -33,7 +33,7 @@ retriever = TfidfRetriever()
|
||||
# A reader scans the text chunks in detail and extracts the k best answers
|
||||
# Readers use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on SQuAD 2.0
|
||||
fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
|
||||
reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)
|
||||
reader = FARMReader(model_name_or_path="model/bert-english-qa-large", use_gpu=False)
|
||||
|
||||
# OR: alternatively, use a reader from huggingface's Transformers package
|
||||
# reader = TransformersReader(use_gpu=-1)
|
||||
|
47
tutorials/Tutorial2_Finetune_a_model_on_your_data.py
Executable file
47
tutorials/Tutorial2_Finetune_a_model_on_your_data.py
Executable file
@ -0,0 +1,47 @@
|
||||
from haystack.reader.farm import FARMReader
|
||||
from haystack.reader.transformers import TransformersReader
|
||||
from haystack.retriever.tfidf import TfidfRetriever
|
||||
from haystack import Finder
|
||||
from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
|
||||
from haystack.indexing.cleaning import clean_wiki_text
|
||||
from haystack.utils import print_answers
|
||||
|
||||
#### TRAINING #############
|
||||
# Let's take a reader as a base model
|
||||
reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)
|
||||
|
||||
# and fine-tune it on your own custom dataset (should be in a SQuAD-like format)
|
||||
reader.train(data_dir="../data/squad_small", train_filename="train.json", use_gpu=False, n_epochs=1)
|
||||
|
||||
|
||||
#### Use it (same as in Tutorial 1) #############
|
||||
|
||||
# Okay, we have a fine-tuned model now. Let's test it on some docs:
|
||||
## Let's get some docs for testing (see Tutorial 1 for more explanations)
|
||||
from haystack.database import db
|
||||
db.create_all()
|
||||
|
||||
# Download docs
|
||||
doc_dir = "data/article_txt_got"
|
||||
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
|
||||
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
|
||||
|
||||
# Write docs to our DB.
|
||||
write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)
|
||||
|
||||
# Initialize Finder Pipeline
|
||||
retriever = TfidfRetriever()
|
||||
finder = Finder(reader, retriever)
|
||||
|
||||
## Voilà! Ask a question!
|
||||
# You can configure how many candidates the reader and retriever shall return
|
||||
# The higher top_k_retriever, the better (but also the slower) your answers.
|
||||
prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
|
||||
|
||||
#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
|
||||
#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
|
||||
|
||||
print_answers(prediction, details="minimal")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user