{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Training Your Own \"Dense Passage Retrieval\" Model\n",
"\n",
"[](https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial9_DPR_training.ipynb)\n",
|
||
|
"\n",
|
||
|
"Haystack contains all the tools needed to train your own Dense Passage Retrieval model.\n",
|
||
|
"This tutorial will guide you through the steps required to create a retriever that is specifically tailored to your domain.\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%% md\n"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Here are some imports that we'll need\n",
|
||
|
"\n",
|
||
|
"from haystack.retriever.dense import DensePassageRetriever\n",
|
||
|
"from haystack.preprocessor.utils import fetch_archive_from_http\n",
|
||
|
"from haystack.document_store.memory import InMemoryDocumentStore"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%%\n"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"## Training Data\n",
|
||
|
"\n",
|
||
|
"DPR training performed using Information Retrieval data.\n",
|
||
|
"More specifically, you want to feed in pairs of queries and relevant documents.\n",
|
||
|
"\n",
|
||
|
"To train a model, we will need a dataset that has the same format as the original DPR training data.\n",
|
||
|
"Each data point in the dataset should have the following dictionary structure.\n",
|
||
|
"\n",
|
||
|
"``` python\n",
|
||
|
" {\n",
|
||
|
" \"dataset\": str,\n",
|
||
|
" \"question\": str,\n",
|
||
|
" \"answers\": list of str\n",
|
||
|
" \"positive_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str}\n",
|
||
|
" \"negative_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str}\n",
|
||
|
" \"hard_negative_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str}\n",
|
||
|
" }\n",
|
||
|
"```\n",
|
||
|
"\n",
|
||
|
"`positive_ctxs` are context passages which are relevant to the query.\n",
|
||
|
"In some datasets, queries might have more than one positive context\n",
|
||
|
"in which case you can set the `num_positives` parameter to be higher than the default 1.\n",
|
||
|
"Note that `num_positives` needs to be lower or equal to the minimum number of `positive_ctxs` for queries in your data.\n",
|
||
|
"If you have an unequal number of positive contexts per example,\n",
|
||
|
"you might want to generate some soft labels by retrieving similar contexts which contain the answer.\n",
|
||
|
"\n",
|
||
|
"DPR is standardly trained using a method known as in-batch negatives.\n",
|
||
|
"This means that positive contexts for a given query are treated as negative contexts for the other queries in the batch.\n",
|
||
|
"Doing so allows for a high degree of computational efficiency, thus allowing the model to be trained on large amounts of data.\n",
"\n",
|
||
|
"`negative_ctxs` is not actually used in Haystack's DPR training so we recommend you set it to an empty list.\n",
|
||
|
"They were used by the original DPR authors in an experiment to compare it against the in-batch negatives method.\n",
|
||
|
"\n",
|
||
|
"`hard_negative_ctxs` are passages that are not relevant to the query.\n",
|
||
|
"In the original DPR paper, these are fetched using a retriever to find the most relevant passages to the query.\n",
|
||
|
"Passages which contain the answer text are filtered out.\n",
|
||
|
"\n",
|
||
|
"We are [currently working](https://github.com/deepset-ai/haystack/issues/705) on a script that will convert SQuAD format data into a DPR dataset!"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Using Question Answering Data\n",
|
||
|
"\n",
|
||
|
"Question Answering datasets can sometimes be used as training data.\n",
|
||
|
"Google's Natural Questions dataset, is sufficiently large\n",
|
||
|
"and contains enough unique passages, that it can be converted into a DPR training set.\n",
|
||
|
"This is done simply by considering answer containing passages as relevant documents to the query.\n",
|
||
|
"\n",
|
||
|
"The SQuAD dataset, however, is not as suited to this use case since its question and answer pairs\n",
|
||
|
"are created on only a very small slice of wikipedia documents."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Download Original DPR Training Data\n",
|
||
|
"\n",
|
||
|
"WARNING: These files are large! The train set is 7.4GB and the dev set is 800MB\n",
|
||
|
"\n",
|
||
|
"We can download the original DPR training data with the following cell.\n",
|
||
|
"Note that this data is probably only useful if you are trying to train from scratch."
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%% md\n"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Download original DPR data\n",
|
||
|
"# WARNING: the train set is 7.4GB and the dev set is 800MB\n",
|
||
|
"\n",
|
||
|
"doc_dir = \"data/dpr_training/\"\n",
|
||
|
"\n",
|
||
|
"s3_url_train = \"https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz\"\n",
|
||
|
"s3_url_dev = \"https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz\"\n",
|
||
|
"\n",
|
||
|
"fetch_archive_from_http(s3_url_train, output_dir=doc_dir + \"train/\")\n",
|
||
|
"fetch_archive_from_http(s3_url_dev, output_dir=doc_dir + \"dev/\")"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%%\n"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
"cell_type": "markdown",
"source": [
"## Option 1: Training DPR from Scratch\n",
"\n",
"The default variables that we provide below are chosen to train a DPR model from scratch.\n",
"Here, both passage and query embedding models are initialized using BERT base\n",
"and the model is trained using Google's Natural Questions dataset (in a format specialised for DPR).\n",
"\n",
"If you are working in a language other than English,\n",
"you will want to initialize the passage and query embedding models with a language model that supports your language\n",
"and also provide a dataset in your language.\n",
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Here are the variables to specify our training data, the models that we use to initialize DPR\n",
"# and the directory where we'll be saving the model\n",
"\n",
"doc_dir = \"data/dpr_training/\"\n",
"\n",
"train_filename = \"train/biencoder-nq-train.json\"\n",
"dev_filename = \"dev/biencoder-nq-dev.json\"\n",
"\n",
"query_model = \"bert-base-uncased\"\n",
"passage_model = \"bert-base-uncased\"\n",
"\n",
"save_dir = \"../saved_models/dpr\""
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Option 2: Finetuning DPR\n",
|
||
|
"\n",
|
||
|
"If you have your own domain specific question answering or information retrieval dataset,\n",
|
||
|
"you might instead be interested in finetuning a pretrained DPR model.\n",
|
||
|
"In this case, you would initialize both query and passage models using the original pretrained model.\n",
|
||
|
"You will want to load something like this set of variables instead of the ones above"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%% md\n"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Here are the variables you might want to use instead of the set above\n",
|
||
|
"# in order to perform pretraining\n",
|
||
|
"\n",
|
||
|
"doc_dir = \"PATH_TO_YOUR_DATA_DIR\"\n",
|
||
|
"train_filename = \"TRAIN_FILENAME\"\n",
|
||
|
"dev_filename = \"DEV_FILENAME\"\n",
|
||
|
"\n",
|
||
|
"query_model = \"facebook/dpr-question_encoder-single-nq-base\"\n",
|
||
|
"passage_model = \"facebook/dpr-ctx_encoder-single-nq-base\"\n",
|
||
|
"\n",
|
||
|
"save_dir = \"../saved_models/dpr\""
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%%\n"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
"cell_type": "markdown",
"source": [
"## Initialization\n",
"\n",
"Here we want to initialize our model either with plain language model weights for training from scratch\n",
"or else with pretrained DPR weights for finetuning.\n",
"We follow the [original DPR parameters](https://github.com/facebookresearch/DPR#best-hyperparameter-settings)\n",
"for their max passage length, but set the max query length to 64 since queries are very rarely longer."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"## Initialize DPR model\n",
"\n",
"retriever = DensePassageRetriever(\n",
"    document_store=InMemoryDocumentStore(),\n",
"    query_embedding_model=query_model,\n",
"    passage_embedding_model=passage_model,\n",
"    max_seq_len_query=64,\n",
"    max_seq_len_passage=256\n",
")"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Training\n",
|
||
|
"\n",
|
||
|
"Let's start training and save our trained model!\n",
|
||
|
"\n",
|
||
|
"On a V100 GPU, you can fit up to batch size 4 so we set gradient accumulation steps to 4 in order\n",
|
||
|
"to simulate the batch size 16 of the original DPR experiment.\n",
|
||
|
"\n",
|
||
|
"When `embed_title=True`, the document title is prepended to the input text sequence with a `[SEP]` token\n",
|
||
|
"between it and document text."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"When training from scratch with the above variables, 1 epoch takes around an hour and we reached the following performance:\n",
"\n",
"```\n",
"loss: 0.09334952129693501\n",
"acc: 0.984035000191887\n",
"f1: 0.936147352264006\n",
"acc_and_f1: 0.9600911762279465\n",
"average_rank: 0.07075978511128166\n",
"```"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Start training our model and save it when it is finished\n",
"\n",
"retriever.train(\n",
"    data_dir=doc_dir,\n",
" train_filename=dev_filename,\n",
|
||
|
" dev_filename=dev_filename,\n",
|
||
|
" test_filename=dev_filename,\n",
|
||
|
" n_epochs=1,\n",
|
||
|
" batch_size=4,\n",
|
||
|
" grad_acc_steps=4,\n",
|
||
|
" save_dir=save_dir,\n",
|
||
|
" evaluate_every=3000,\n",
|
||
|
" embed_title=True,\n",
|
||
|
" num_positives=1,\n",
|
||
|
" num_hard_negatives=1\n",
|
||
|
")"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%%\n"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"## Loading\n",
|
||
|
"\n",
|
||
|
"Loading our newly trained model is simple!"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%% md\n"
|
||
|
}
|
||
|
}
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"reloaded_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=None)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"pycharm": {
|
||
|
"name": "#%%\n"
|
||
|
}
|
||
|
}
|
||
|
},
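{
"cell_type": "markdown",
"source": [
"To sanity-check the reloaded retriever, you might attach it to a document store and run a query. The following is a rough sketch rather than a tested recipe: the exact document dictionary format (for example the \"text\" field) can differ between Haystack versions, so adapt it to the version you have installed.\n",
"\n",
"```python\n",
"document_store = InMemoryDocumentStore()\n",
"document_store.write_documents([{\"text\": \"Dense Passage Retrieval uses two BERT-style encoders.\"}])\n",
"\n",
"reloaded_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=document_store)\n",
"document_store.update_embeddings(retriever=reloaded_retriever)\n",
"\n",
"print(reloaded_retriever.retrieve(query=\"What does DPR use?\", top_k=1))\n",
"```"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
}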
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"source": [],
"metadata": {
"collapsed": false
}
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}