Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-01 12:23:31 +00:00)
Add method to train a reader on custom data (#5)
* initial version of training a reader (WIP)
* update for latest changes in FARM Inferencer; update tutorial; add basic docs
parent 8a48cd7dd6
commit 1718ea55b8
.gitignore (vendored, 11 additions)
@@ -128,5 +128,16 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
+# PyCharm
+.idea
+
 # haystack files
 haystack/database/qa.db
+data
+mlruns
+src
+tutorials/cache
+tutorials/mlruns
+tutorials/model
+model
+
@@ -8,6 +8,7 @@ pd.options.display.max_colwidth = 80
 logger = logging.getLogger(__name__)
 
 logging.getLogger('farm').setLevel(logging.WARNING)
+logging.getLogger('farm.infer').setLevel(logging.INFO)
 logging.getLogger('transformers').setLevel(logging.WARNING)
 
 
@@ -38,7 +38,7 @@ if len(model_paths) == 0:
 retriever = TfidfRetriever()
 FINDERS = {}
 for idx, model_dir in enumerate(model_paths, start=1):
-    reader = FARMReader(model_dir=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
+    reader = FARMReader(model_name_or_path=str(model_dir), batch_size=BATCH_SIZE, use_gpu=USE_GPU)
     FINDERS[idx] = Finder(reader, retriever)
     logger.info(f"Initialized Finder (ID={idx}) with model '{model_dir}'")
 
@@ -1,36 +1,146 @@
-from farm.infer import Inferencer
-
 import numpy as np
 from scipy.special import expit
+from pathlib import Path
+import logging
+
+from farm.data_handler.data_silo import DataSilo
+from farm.data_handler.processor import SquadProcessor
+from farm.infer import Inferencer
+from farm.modeling.optimization import initialize_optimizer
+from farm.train import Trainer
+from farm.utils import set_all_seeds, initialize_device_settings
+
+logger = logging.getLogger(__name__)
+
 
 class FARMReader:
     """
-    Implementation of FARM Inferencer for Question Answering.
+    Transformer based model for extractive Question Answering using the FARM framework (https://github.com/deepset-ai/FARM).
+    While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
 
-    The class loads a saved FARM adaptive model from a given directory and runs
-    inference using `inference_from_dicts()` method.
+    With a FARMReader, you can:
+    - directly get predictions via predict()
+    - fine-tune the model on QA data via train()
     """
 
     def __init__(
         self,
-        model_dir,
-        context_size=30,
-        no_answer_shift=-100,
+        model_name_or_path,
+        context_window_size=30,
+        no_ans_threshold=-100,
         batch_size=16,
         use_gpu=True,
-        n_best_per_passage=2
-    ):
+        n_candidates_per_passage=2):
         """
-        Load a saved FARM model in Inference mode.
-
-        :param model_dir: directory path of the saved model
+        :param model_name_or_path: directory of a saved model or the name of a public model:
+                                   - 'bert-base-cased'
+                                   - 'deepset/bert-base-cased-squad2'
+                                   - 'distilbert-base-uncased-distilled-squad'
+                                   ....
+                                   See https://huggingface.co/models for full list of available models.
+        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
+        :param no_ans_threshold: How much greater the no_answer logit needs to be over the pos_answer in order to be chosen.
+                                 The higher the value, the more `uncertain` answers are accepted
+        :param batch_size: Number of samples the model receives in one batch for inference
+        :param use_gpu: Whether to use GPU (if available)
+        :param n_candidates_per_passage: How many candidate answers are extracted per text sequence that the model can process at once (depends on `max_seq_len`).
+                                         Note: This is not the number of "final answers" you will receive
+                                         (see `top_k` in FARMReader.predict() or Finder.get_answers() for that)
         """
-        self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu)
-        self.model.model.prediction_heads[0].context_size = context_size
-        self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift
-        self.model.model.prediction_heads[0].n_best = n_best_per_passage
+        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, task_type="question_answering")
+        self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
+        self.inferencer.model.prediction_heads[0].no_ans_threshold = no_ans_threshold
+        self.inferencer.model.prediction_heads[0].n_best = n_candidates_per_passage
+
+    def train(self, data_dir, train_filename, dev_filename=None, test_file_name=None,
+              use_gpu=True, batch_size=10, n_epochs=2, learning_rate=1e-5,
+              max_seq_len=256, warmup_proportion=0.2, dev_split=0.1, evaluate_every=300, save_dir=None):
+        """
+        Fine-tune a model on a QA dataset. Options:
+        - Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
+        - Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
+
+        :param data_dir: Path to directory containing your training data in SQuAD style
+        :param train_filename: filename of training data
+        :param dev_filename: filename of dev / eval data
+        :param test_file_name: filename of test data
+        :param dev_split: Instead of specifying a dev_filename you can also specify a ratio (e.g. 0.1) here
+                          that gets split off from training data for eval.
+        :param use_gpu: Whether to use GPU (if available)
+        :param batch_size: Number of samples the model receives in one batch for training
+        :param n_epochs: number of iterations on the whole training data set
+        :param learning_rate: learning rate of the optimizer
+        :param max_seq_len: maximum text length (in tokens). Everything longer gets cut down.
+        :param warmup_proportion: Proportion of training steps until maximum learning rate is reached.
+                                  Until that point LR is increasing linearly. After that it's decreasing again linearly.
+                                  Options for different schedules are available in FARM.
+        :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
+        :param save_dir: Path to store the final model
+        :return: None
+        """
+
+        if dev_filename:
+            dev_split = None
+
+        set_all_seeds(seed=42)
+        device, n_gpu = initialize_device_settings(use_cuda=use_gpu)
+
+        if not save_dir:
+            save_dir = f"../../saved_models/{self.inferencer.model.language_model.name}"
+        save_dir = Path(save_dir)
+
+        # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
+        label_list = ["start_token", "end_token"]
+        metric = "squad"
+        processor = SquadProcessor(
+            tokenizer=self.inferencer.processor.tokenizer,
+            max_seq_len=max_seq_len,
+            label_list=label_list,
+            metric=metric,
+            train_filename=train_filename,
+            dev_filename=dev_filename,
+            dev_split=dev_split,
+            test_filename=test_file_name,
+            data_dir=Path(data_dir),
+        )
+
+        # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
+        # and calculates a few descriptive statistics of our datasets
+        data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)
+
+        # 3. Create an optimizer and pass the already initialized model
+        model, optimizer, lr_schedule = initialize_optimizer(
+            model=self.inferencer.model,
+            learning_rate=learning_rate,
+            schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
+            n_batches=len(data_silo.loaders["train"]),
+            n_epochs=n_epochs,
+            device=device
+        )
+        # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time
+        trainer = Trainer(
+            model=model,
+            optimizer=optimizer,
+            data_silo=data_silo,
+            epochs=n_epochs,
+            n_gpu=n_gpu,
+            lr_schedule=lr_schedule,
+            evaluate_every=evaluate_every,
+            device=device,
+        )
+        # 5. Let it grow!
+        self.inferencer.model = trainer.train()
+        self.save(save_dir)
+
+    def save(self, directory):
+        logger.info(f"Saving reader model to {directory}")
+        self.inferencer.model.save(directory)
+        self.inferencer.processor.save(directory)
 
     def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
         """
         Use loaded QA model to find answers for a question in the supplied paragraphs.
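Taken together, the new __init__ and train() signatures above describe the intended workflow. A minimal usage sketch (not part of the diff; the data directory, filenames, and save_dir are placeholders you would supply yourself):

from haystack.reader.farm import FARMReader

# Load a public QA model by name (or point model_name_or_path to a local model directory)
reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)

# Fine-tune it on your own SQuAD-style data and persist the result
reader.train(data_dir="my_squad_data", train_filename="train.json",
             use_gpu=False, n_epochs=1, save_dir="my_finetuned_model")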
@@ -74,7 +184,7 @@ class FARMReader:
             input_dicts.append(cur)
 
         # get answers from QA model (Top 5 per input paragraph)
-        predictions = self.model.inference_from_dicts(
+        predictions = self.inferencer.inference_from_dicts(
             dicts=input_dicts, rest_api_schema=True, max_processes=max_processes
         )
 
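Besides going through a Finder, the reader can be queried directly via the predict() method shown above. A hedged sketch, assuming the paragraphs can be passed as plain strings as the docstring suggests (note the parameter really is spelled `paragrahps` in this version; the texts are made-up placeholders, and `reader` is the FARMReader from the sketch above):

docs = ["Arya Stark is the daughter of Eddard Stark and Catelyn Stark.",
        "Daenerys Targaryen is known as the Mother of Dragons."]
result = reader.predict(question="Who is the father of Arya Stark?", paragrahps=docs, top_k=3)
print(result)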
@@ -3,8 +3,12 @@ from transformers import pipeline
 
 class TransformersReader:
     """
-    A reader using the QA Pipeline class from huggingface's Transformers.
-    Easily load any of the pretrained community QA models from here: https://huggingface.co/models
+    Transformer based model for extractive Question Answering using the huggingface's transformers framework
+    (https://github.com/huggingface/transformers).
+    While the underlying model can vary (BERT, Roberta, DistilBERT ...) the interface remains the same.
+
+    With the reader, you can:
+    - directly get predictions via predict()
     """
 
     def __init__(
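The updated TransformersReader docstring mirrors the FARMReader one: the underlying model is swappable and predictions come from predict(). A minimal construction sketch, reusing the commented-out call from Tutorial 1 below (use_gpu=-1 presumably selects the CPU, following the transformers pipeline convention):

from haystack.reader.transformers import TransformersReader

# Drop-in replacement for a FARMReader inside a Finder
reader = TransformersReader(use_gpu=-1)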
@@ -33,7 +33,7 @@ retriever = TfidfRetriever()
 # A reader scans the text chunks in detail and extracts the k best answers
 # Reader use more powerful but slower deep learning models, here: a BERT QA model trained via FARM on Squad 2.0
 fetch_archive_from_http(url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-models/0.3.0/bert-english-qa-large.tar.gz", output_dir="model")
-reader = FARMReader(model_dir="model/bert-english-qa-large", use_gpu=False)
+reader = FARMReader(model_name_or_path="model/bert-english-qa-large", use_gpu=False)
 
 # OR: use alternatively a reader from huggingface's Transformers package
 # reader = TransformersReader(use_gpu=-1)
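Since the parameter is now model_name_or_path, the manual model download above becomes optional: a public model listed in the FARMReader docstring can be loaded directly by name. A one-line sketch (not part of this diff):

reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)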
tutorials/Tutorial2_Finetune_a_model_on_your_data.py (new executable file, 47 lines)
@@ -0,0 +1,47 @@
+from haystack.reader.farm import FARMReader
+from haystack.reader.transformers import TransformersReader
+from haystack.retriever.tfidf import TfidfRetriever
+from haystack import Finder
+from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
+from haystack.indexing.cleaning import clean_wiki_text
+from haystack.utils import print_answers
+
+#### TRAINING #############
+# Let's take a reader as a base model
+reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=False)
+
+# and fine-tune it on your own custom dataset (should be in SQuAD like format)
+reader.train(data_dir="../data/squad_small", train_filename="train.json", use_gpu=False, n_epochs=1)
+
+
+#### Use it (same as in Tutorial 1) #############
+
+# Okay, we have a fine-tuned model now. Let's test it on some docs:
+## Let's get some docs for testing (see Tutorial 1 for more explanations)
+from haystack.database import db
+db.create_all()
+
+# Download docs
+doc_dir = "data/article_txt_got"
+s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt.zip"
+fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
+
+# Write docs to our DB.
+write_documents_to_db(document_dir=doc_dir, clean_func=clean_wiki_text, only_empty_db=True)
+
+# Initialize Finder Pipeline
+retriever = TfidfRetriever()
+finder = Finder(reader, retriever)
+
+## Voilá! Ask a question!
+# You can configure how many candidates the reader and retriever shall return
+# The higher top_k_retriever, the better (but also the slower) your answers.
+prediction = finder.get_answers(question="Who is the father of Arya Stark?", top_k_retriever=10, top_k_reader=5)
+
+#prediction = finder.get_answers(question="Who created the Dothraki vocabulary?", top_k_reader=5)
+#prediction = finder.get_answers(question="Who is the sister of Sansa?", top_k_reader=5)
+
+print_answers(prediction, details="minimal")
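The call to reader.train() in this tutorial expects train.json in SQuAD-like format. For orientation, a hedged sketch of how such a file could be produced (example content, title, and id are made up; field names follow the public SQuAD 2.0 schema, so compare with an official SQuAD file if FARM's SquadProcessor rejects it):

import json
from pathlib import Path

# Hypothetical one-document dataset in SQuAD-like form
squad_like = {
    "data": [
        {
            "title": "Game of Thrones",
            "paragraphs": [
                {
                    "context": "Arya Stark is the daughter of Eddard Stark and Catelyn Stark.",
                    "qas": [
                        {
                            "id": "1",
                            "question": "Who is the father of Arya Stark?",
                            "answers": [{"text": "Eddard Stark", "answer_start": 30}],
                            "is_impossible": False,
                        }
                    ],
                }
            ],
        }
    ]
}

# Write it where the tutorial's data_dir/train_filename point
target = Path("../data/squad_small")
target.mkdir(parents=True, exist_ok=True)
with open(target / "train.json", "w") as f:
    json.dump(squad_like, f)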