{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Training Your Own \"Dense Passage Retrieval\" Model\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial9_DPR_training.ipynb)\n",
    "\n",
    "Haystack contains all the tools needed to train your own Dense Passage Retrieval model.\n",
    "This tutorial will guide you through the steps required to create a retriever that is specifically tailored to your domain."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Install the latest release of Haystack in your own environment\n",
    "#! pip install farm-haystack\n",
    "\n",
    "# Install the latest master of Haystack\n",
    "!pip install git+https://github.com/deepset-ai/haystack.git\n",
    "!pip install torch==1.6.0+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Here are some imports that we'll need\n",
    "\n",
    "from haystack.retriever.dense import DensePassageRetriever\n",
    "from haystack.preprocessor.utils import fetch_archive_from_http\n",
    "from haystack.document_store.memory import InMemoryDocumentStore"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training Data\n",
    "\n",
    "DPR training is performed using Information Retrieval data.\n",
    "More specifically, you want to feed in pairs of queries and relevant documents.\n",
    "\n",
    "To train a model, we will need a dataset that has the same format as the original DPR training data.\n",
    "Each data point in the dataset should have the following dictionary structure.\n",
    "\n",
    "``` python\n",
    "    {\n",
    "        \"dataset\": str,\n",
    "        \"question\": str,\n",
    "        \"answers\": list of str,\n",
    "        \"positive_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str},\n",
    "        \"negative_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str},\n",
    "        \"hard_negative_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str}\n",
    "    }\n",
    "```\n",
    "\n",
    "`positive_ctxs` are context passages which are relevant to the query.\n",
    "In some datasets, queries might have more than one positive context,\n",
    "in which case you can set the `num_positives` parameter to be higher than the default 1.\n",
    "Note that `num_positives` must be less than or equal to the minimum number of `positive_ctxs` across the queries in your data.\n",
    "If you have an unequal number of positive contexts per example,\n",
    "you might want to generate some soft labels by retrieving similar contexts which contain the answer.\n",
    "\n",
    "DPR is typically trained using a method known as in-batch negatives.\n",
    "This means that the positive contexts for a given query are treated as negative contexts for the other queries in the batch.\n",
    "Doing so allows for a high degree of computational efficiency, thus allowing the model to be trained on large amounts of data.\n",
    "\n",
    "`negative_ctxs` are not actually used in Haystack's DPR training, so we recommend you set them to an empty list.\n",
    "The original DPR authors used them in an experiment comparing them against the in-batch negatives method.\n",
    "\n",
    "`hard_negative_ctxs` are passages that are not relevant to the query.\n",
    "In the original DPR paper, these are fetched using a retriever to find the passages most similar to the query.\n",
    "Passages which contain the answer text are filtered out.\n",
    "\n",
    "An illustrative example of a single data point is shown in the cell below.\n",
    "\n",
    "We are [currently working](https://github.com/deepset-ai/haystack/issues/705) on a script that will convert SQuAD format data into a DPR dataset!"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
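  {
   "cell_type": "markdown",
   "source": [
    "For illustration, here is a minimal, made-up data point in this format.\n",
    "The question, passages, scores, and IDs below are invented; only the field layout matters.\n",
    "\n",
    "``` python\n",
    "{\n",
    "    \"dataset\": \"my_domain_dataset\",\n",
    "    \"question\": \"Who wrote the novel Dune?\",\n",
    "    \"answers\": [\"Frank Herbert\"],\n",
    "    \"positive_ctxs\": [\n",
    "        {\"title\": \"Dune (novel)\", \"text\": \"Dune is a 1965 science fiction novel by Frank Herbert...\", \"score\": 1000, \"title_score\": 1, \"passage_id\": \"42\"}\n",
    "    ],\n",
    "    \"negative_ctxs\": [],\n",
    "    \"hard_negative_ctxs\": [\n",
    "        {\"title\": \"Dune (1984 film)\", \"text\": \"Dune is a 1984 film directed by David Lynch...\", \"score\": 13, \"title_score\": 0, \"passage_id\": \"1337\"}\n",
    "    ]\n",
    "}\n",
    "```"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },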
  {
   "cell_type": "markdown",
   "source": [
    "## Using Question Answering Data\n",
    "\n",
    "Question Answering datasets can sometimes be used as training data.\n",
    "Google's Natural Questions dataset is sufficiently large,\n",
    "and contains enough unique passages, that it can be converted into a DPR training set.\n",
    "This is done simply by treating answer-containing passages as documents that are relevant to the query.\n",
    "\n",
    "The SQuAD dataset, however, is not as well suited to this use case since its question and answer pairs\n",
    "are created on only a very small slice of Wikipedia documents.\n",
    "The cell below sketches what such a conversion from QA pairs might look like."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
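  {
   "cell_type": "markdown",
   "source": [
    "This is a minimal sketch, not part of the official tutorial flow: `qa_pairs` and `passages` are placeholder collections standing in for your own data,\n",
    "and a passage is treated as positive simply if it literally contains the answer string.\n",
    "\n",
    "``` python\n",
    "import json\n",
    "\n",
    "# Placeholder inputs: swap in your own QA pairs and passage collection\n",
    "qa_pairs = [(\"Who wrote the novel Dune?\", \"Frank Herbert\")]\n",
    "passages = [\n",
    "    {\"title\": \"Dune (novel)\", \"text\": \"Dune is a 1965 novel by Frank Herbert.\"},\n",
    "    {\"title\": \"Dune (1984 film)\", \"text\": \"Dune is a 1984 film by David Lynch.\"},\n",
    "]\n",
    "\n",
    "dataset = []\n",
    "for question, answer in qa_pairs:\n",
    "    # A passage counts as positive if it contains the answer text\n",
    "    positives = [\n",
    "        {\"title\": p[\"title\"], \"text\": p[\"text\"], \"score\": 1000,\n",
    "         \"title_score\": 1, \"passage_id\": str(i)}\n",
    "        for i, p in enumerate(passages) if answer in p[\"text\"]\n",
    "    ]\n",
    "    if positives:\n",
    "        dataset.append({\n",
    "            \"dataset\": \"my_converted_qa_dataset\",\n",
    "            \"question\": question,\n",
    "            \"answers\": [answer],\n",
    "            \"positive_ctxs\": positives,\n",
    "            # hard_negative_ctxs could be filled with retriever hits\n",
    "            # that do not contain the answer\n",
    "            \"negative_ctxs\": [],\n",
    "            \"hard_negative_ctxs\": [],\n",
    "        })\n",
    "\n",
    "with open(\"my_dpr_train.json\", \"w\") as f:\n",
    "    json.dump(dataset, f)\n",
    "```"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },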
  {
   "cell_type": "markdown",
   "source": [
    "## Download Original DPR Training Data\n",
    "\n",
    "WARNING: These files are large! The train set is 7.4GB and the dev set is 800MB\n",
    "\n",
    "We can download the original DPR training data with the following cell.\n",
    "Note that this data is probably only useful if you are trying to train from scratch."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Download original DPR data\n",
    "# WARNING: the train set is 7.4GB and the dev set is 800MB\n",
    "\n",
    "doc_dir = \"data/dpr_training/\"\n",
    "\n",
    "s3_url_train = \"https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz\"\n",
    "s3_url_dev = \"https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz\"\n",
    "\n",
    "fetch_archive_from_http(s3_url_train, output_dir=doc_dir + \"train/\")\n",
    "fetch_archive_from_http(s3_url_dev, output_dir=doc_dir + \"dev/\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
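  {
   "cell_type": "markdown",
   "source": [
    "Optionally, we can sanity-check the download by loading the dev set and inspecting one example.\n",
    "This assumes `fetch_archive_from_http` extracted the `.gz` archives into plain `.json` files (the filenames match those used in the training cells below); loading the 800MB file can take a while."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Optional sanity check: load the dev set and look at the first example\n",
    "with open(doc_dir + \"dev/biencoder-nq-dev.json\") as f:\n",
    "    dev_data = json.load(f)\n",
    "\n",
    "print(f\"Number of dev examples: {len(dev_data)}\")\n",
    "print(f\"Fields of one example: {sorted(dev_data[0].keys())}\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },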
  {
   "cell_type": "markdown",
   "source": [
    "## Option 1: Training DPR from Scratch\n",
    "\n",
    "The default variables that we provide below are chosen to train a DPR model from scratch.\n",
    "Here, both passage and query embedding models are initialized using BERT base\n",
    "and the model is trained using Google's Natural Questions dataset (in a format specialised for DPR).\n",
    "\n",
    "If you are working in a language other than English,\n",
    "you will want to initialize the passage and query embedding models with a language model that supports your language\n",
    "and also provide a dataset in your language."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Here are the variables to specify our training data, the models that we use to initialize DPR\n",
    "# and the directory where we'll be saving the model\n",
    "\n",
    "doc_dir = \"data/dpr_training/\"\n",
    "\n",
    "train_filename = \"train/biencoder-nq-train.json\"\n",
    "dev_filename = \"dev/biencoder-nq-dev.json\"\n",
    "\n",
    "query_model = \"bert-base-uncased\"\n",
    "passage_model = \"bert-base-uncased\"\n",
    "\n",
    "save_dir = \"../saved_models/dpr\""
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Option 2: Finetuning DPR\n",
    "\n",
    "If you have your own domain-specific question answering or information retrieval dataset,\n",
    "you might instead be interested in finetuning a pretrained DPR model.\n",
    "In this case, you would initialize both query and passage models using the original pretrained model.\n",
    "You will want to load something like this set of variables instead of the ones above."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Here are the variables you might want to use instead of the set above\n",
    "# in order to perform finetuning\n",
    "\n",
    "doc_dir = \"PATH_TO_YOUR_DATA_DIR\"\n",
    "train_filename = \"TRAIN_FILENAME\"\n",
    "dev_filename = \"DEV_FILENAME\"\n",
    "\n",
    "query_model = \"facebook/dpr-question_encoder-single-nq-base\"\n",
    "passage_model = \"facebook/dpr-ctx_encoder-single-nq-base\"\n",
    "\n",
    "save_dir = \"../saved_models/dpr\""
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Initialization\n",
    "\n",
    "Here we want to initialize our model either with plain language model weights for training from scratch\n",
    "or else with pretrained DPR weights for finetuning.\n",
    "We follow the [original DPR parameters](https://github.com/facebookresearch/DPR#best-hyperparameter-settings)\n",
    "for their max passage length but set max query length to 64 since queries are very rarely longer."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "## Initialize DPR model\n",
    "\n",
    "retriever = DensePassageRetriever(\n",
    "    document_store=InMemoryDocumentStore(),\n",
    "    query_embedding_model=query_model,\n",
    "    passage_embedding_model=passage_model,\n",
    "    max_seq_len_query=64,\n",
    "    max_seq_len_passage=256\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training\n",
    "\n",
    "Let's start training and save our trained model!\n",
    "\n",
    "On a V100 GPU, you can fit up to batch size 4, so we set gradient accumulation steps to 4 in order\n",
    "to simulate the batch size of 16 used in the original DPR experiment.\n",
    "\n",
    "When `embed_title=True`, the document title is prepended to the input text sequence, with a `[SEP]` token\n",
    "between it and the document text."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
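  {
   "cell_type": "markdown",
   "source": [
    "In other words, with `batch_size=4` and `grad_acc_steps=4`, gradients are accumulated over 4 batches before each optimizer step, for an effective batch size of 4 * 4 = 16.\n",
    "With `embed_title=True`, the passage encoder input then looks roughly like this (illustrative only; the exact tokenization happens inside Haystack):\n",
    "\n",
    "```\n",
    "[CLS] passage title [SEP] passage text [SEP]\n",
    "```"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },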
  {
   "cell_type": "markdown",
   "source": [
    "When training from scratch with the above variables, 1 epoch takes around an hour and we reached the following performance:\n",
    "\n",
    "```\n",
    "loss: 0.09334952129693501\n",
    "acc: 0.984035000191887\n",
    "f1: 0.936147352264006\n",
    "acc_and_f1: 0.9600911762279465\n",
    "average_rank: 0.07075978511128166\n",
    "```"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Start training our model and save it when it is finished\n",
    "\n",
    "retriever.train(\n",
    "    data_dir=doc_dir,\n",
    "    train_filename=train_filename,\n",
    "    dev_filename=dev_filename,\n",
    "    test_filename=dev_filename,\n",
    "    n_epochs=1,\n",
    "    batch_size=4,\n",
    "    grad_acc_steps=4,\n",
    "    save_dir=save_dir,\n",
    "    evaluate_every=3000,\n",
    "    embed_title=True,\n",
    "    num_positives=1,\n",
    "    num_hard_negatives=1\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Loading\n",
    "\n",
    "Loading our newly trained model is simple!"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "reloaded_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=None)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
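  {
   "cell_type": "markdown",
   "source": [
    "As a quick sanity check, we can embed a query with the reloaded retriever.\n",
    "This is a minimal sketch assuming the Haystack 0.x API used throughout this tutorial, where `embed_queries` returns one embedding vector per input text."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Sanity check: embed a query with the reloaded retriever.\n",
    "# embed_queries should return one embedding per input text.\n",
    "query_embeddings = reloaded_retriever.embed_queries(texts=[\"What is Dense Passage Retrieval?\"])\n",
    "print(query_embeddings[0].shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }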
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
