{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Training Your Own \"Dense Passage Retrieval\" Model\n",
        "\n",
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial9_DPR_training.ipynb)\n",
        "\n",
        "Haystack contains all the tools needed to train your own Dense Passage Retrieval model.\n",
        "This tutorial will guide you through the steps required to create a retriever that is specifically tailored to your domain."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Install the latest release of Haystack in your own environment\n",
        "#! pip install farm-haystack\n",
        "\n",
        "# Install the latest master of Haystack\n",
        "!pip install grpcio-tools==1.34.1\n",
        "!pip install git+https://github.com/deepset-ai/haystack.git"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Here are some imports that we'll need\n",
        "\n",
        "from haystack.retriever.dense import DensePassageRetriever\n",
        "from haystack.preprocessor.utils import fetch_archive_from_http\n",
        "from haystack.document_store.memory import InMemoryDocumentStore"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Training Data\n",
        "\n",
        "DPR training is performed using Information Retrieval data.\n",
        "More specifically, you want to feed in pairs of queries and relevant documents.\n",
        "\n",
        "To train a model, we will need a dataset that has the same format as the original DPR training data.\n",
        "Each data point in the dataset should have the following dictionary structure.\n",
        "\n",
        "``` python\n",
        "    {\n",
        "        \"dataset\": str,\n",
        "        \"question\": str,\n",
        "        \"answers\": list of str,\n",
        "        \"positive_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str},\n",
        "        \"negative_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str},\n",
        "        \"hard_negative_ctxs\": list of dictionaries of format {'title': str, 'text': str, 'score': int, 'title_score': int, 'passage_id': str}\n",
        "    }\n",
        "```\n",
        "\n",
        "`positive_ctxs` are context passages which are relevant to the query.\n",
        "In some datasets, queries might have more than one positive context,\n",
        "in which case you can set the `num_positives` parameter to be higher than the default 1.\n",
        "Note that `num_positives` needs to be lower than or equal to the minimum number of `positive_ctxs` for queries in your data.\n",
        "If you have an unequal number of positive contexts per example,\n",
        "you might want to generate some soft labels by retrieving similar contexts which contain the answer.\n",
        "\n",
        "DPR is standardly trained using a method known as in-batch negatives.\n",
        "This means that positive contexts for a given query are treated as negative contexts for the other queries in the batch.\n",
        "Doing so allows for a high degree of computational efficiency, thus allowing the model to be trained on large amounts of data.\n",
        "\n",
        "`negative_ctxs` is not actually used in Haystack's DPR training, so we recommend you set it to an empty list.\n",
        "The original DPR authors used these explicit negatives in an experiment comparing them against the in-batch negatives method.\n",
        "\n",
        "`hard_negative_ctxs` are passages that are not relevant to the query.\n",
        "In the original DPR paper, these are fetched using a retriever to find the most relevant passages to the query.\n",
        "Passages which contain the answer text are filtered out.\n",
        "\n",
        "If you'd like to convert your SQuAD format data into something that can train a DPR model,\n",
        "check out the utility script at [`haystack/retriever/squad_to_dpr.py`](https://github.com/deepset-ai/haystack/blob/master/haystack/retriever/squad_to_dpr.py)."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
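    {
      "cell_type": "markdown",
      "source": [
        "To make the format concrete, below is a minimal, hand-written example of a single training data point.\n",
        "The question, answer, and passage texts are invented purely for illustration."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# A minimal, hand-written example of one DPR training data point\n",
        "# (all texts here are invented purely for illustration)\n",
        "sample = {\n",
        "    \"dataset\": \"my_domain_dataset\",\n",
        "    \"question\": \"Where is the Eiffel Tower located?\",\n",
        "    \"answers\": [\"Paris\"],\n",
        "    \"positive_ctxs\": [{\"title\": \"Eiffel Tower\",\n",
        "                       \"text\": \"The Eiffel Tower is a wrought-iron lattice tower in Paris, France.\",\n",
        "                       \"score\": 1000, \"title_score\": 1, \"passage_id\": \"p1\"}],\n",
        "    \"negative_ctxs\": [],  # not used by Haystack's DPR training\n",
        "    \"hard_negative_ctxs\": [{\"title\": \"Tower Bridge\",\n",
        "                            \"text\": \"Tower Bridge is a combined bascule and suspension bridge in London.\",\n",
        "                            \"score\": 0, \"title_score\": 0, \"passage_id\": \"p2\"}],\n",
        "}"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },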
    {
      "cell_type": "markdown",
      "source": [
        "## Using Question Answering Data\n",
        "\n",
        "Question Answering datasets can sometimes be used as training data.\n",
        "Google's Natural Questions dataset is sufficiently large\n",
        "and contains enough unique passages that it can be converted into a DPR training set.\n",
        "This is done simply by treating answer-containing passages as documents relevant to the query, as sketched in the cell below.\n",
        "\n",
        "The SQuAD dataset, however, is not as suited to this use case since its question and answer pairs\n",
        "are created on only a very small slice of Wikipedia documents."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
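    {
      "cell_type": "markdown",
      "source": [
        "Below is a minimal sketch of this conversion idea: a passage counts as a positive context if it contains the answer text.\n",
        "The `qa_records` input and its field names (`question`, `answer`, `passages`) are hypothetical placeholders for your own QA data, not a Haystack API."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Minimal sketch of converting QA records into the DPR training format:\n",
        "# a passage counts as a positive context if it contains the answer text.\n",
        "# `qa_records` and its field names are hypothetical placeholders.\n",
        "def qa_to_dpr(qa_records):\n",
        "    dpr_data = []\n",
        "    for record in qa_records:\n",
        "        # Keep only passages that contain the answer text\n",
        "        positives = [\n",
        "            {\"title\": p[\"title\"], \"text\": p[\"text\"],\n",
        "             \"score\": 0, \"title_score\": 0, \"passage_id\": p[\"id\"]}\n",
        "            for p in record[\"passages\"]\n",
        "            if record[\"answer\"] in p[\"text\"]\n",
        "        ]\n",
        "        if positives:\n",
        "            dpr_data.append({\n",
        "                \"dataset\": \"converted_qa\",\n",
        "                \"question\": record[\"question\"],\n",
        "                \"answers\": [record[\"answer\"]],\n",
        "                \"positive_ctxs\": positives,\n",
        "                \"negative_ctxs\": [],\n",
        "                \"hard_negative_ctxs\": [],\n",
        "            })\n",
        "    return dpr_data"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },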
    {
      "cell_type": "markdown",
      "source": [
        "## Download Original DPR Training Data\n",
        "\n",
        "WARNING: These files are large! The train set is 7.4GB and the dev set is 800MB.\n",
        "\n",
        "We can download the original DPR training data with the following cell.\n",
        "Note that this data is probably only useful if you are trying to train from scratch."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Download original DPR data\n",
        "# WARNING: the train set is 7.4GB and the dev set is 800MB\n",
        "\n",
        "doc_dir = \"data/dpr_training/\"\n",
        "\n",
        "s3_url_train = \"https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz\"\n",
        "s3_url_dev = \"https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz\"\n",
        "\n",
        "fetch_archive_from_http(s3_url_train, output_dir=doc_dir + \"train/\")\n",
        "fetch_archive_from_http(s3_url_dev, output_dir=doc_dir + \"dev/\")"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Option 1: Training DPR from Scratch\n",
        "\n",
        "The default variables that we provide below are chosen to train a DPR model from scratch.\n",
        "Here, both passage and query embedding models are initialized using BERT base\n",
        "and the model is trained using Google's Natural Questions dataset (in a format specialised for DPR).\n",
        "\n",
        "If you are working in a language other than English,\n",
        "you will want to initialize the passage and query embedding models with a language model that supports your language\n",
        "and also provide a dataset in your language."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Here are the variables to specify our training data, the models that we use to initialize DPR\n",
        "# and the directory where we'll be saving the model\n",
        "\n",
        "doc_dir = \"data/dpr_training/\"\n",
        "\n",
        "train_filename = \"train/biencoder-nq-train.json\"\n",
        "dev_filename = \"dev/biencoder-nq-dev.json\"\n",
        "\n",
        "query_model = \"bert-base-uncased\"\n",
        "passage_model = \"bert-base-uncased\"\n",
        "\n",
        "save_dir = \"../saved_models/dpr\""
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Option 2: Finetuning DPR\n",
        "\n",
        "If you have your own domain-specific question answering or information retrieval dataset,\n",
        "you might instead be interested in finetuning a pretrained DPR model.\n",
        "In this case, you would initialize both query and passage models using the original pretrained model.\n",
        "You will want to load something like this set of variables instead of the ones above."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Here are the variables you might want to use instead of the set above\n",
        "# in order to perform finetuning\n",
        "\n",
        "doc_dir = \"PATH_TO_YOUR_DATA_DIR\"\n",
        "train_filename = \"TRAIN_FILENAME\"\n",
        "dev_filename = \"DEV_FILENAME\"\n",
        "\n",
        "query_model = \"facebook/dpr-question_encoder-single-nq-base\"\n",
        "passage_model = \"facebook/dpr-ctx_encoder-single-nq-base\"\n",
        "\n",
        "save_dir = \"../saved_models/dpr\""
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Initialization\n",
        "\n",
        "Here we want to initialize our model either with plain language model weights for training from scratch\n",
        "or else with pretrained DPR weights for finetuning.\n",
        "We follow the [original DPR parameters](https://github.com/facebookresearch/DPR#best-hyperparameter-settings)\n",
        "for their max passage length but set max query length to 64 since queries are very rarely longer."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Initialize the DPR model\n",
        "\n",
        "retriever = DensePassageRetriever(\n",
        "    document_store=InMemoryDocumentStore(),\n",
        "    query_embedding_model=query_model,\n",
        "    passage_embedding_model=passage_model,\n",
        "    max_seq_len_query=64,\n",
        "    max_seq_len_passage=256\n",
        ")"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Training\n",
        "\n",
        "Let's start training and save our trained model!\n",
        "\n",
        "On a V100 GPU, you can fit up to batch size 16, so we set gradient accumulation steps to 8 in order\n",
        "to simulate the batch size of 128 used in the original DPR experiment (16 x 8 = 128).\n",
        "\n",
        "When `embed_title=True`, the document title is prepended to the input text sequence with a `[SEP]` token\n",
        "between it and the document text, as illustrated in the cell below."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
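    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Schematic illustration of the passage encoder input when embed_title=True\n",
        "# (a sketch of the resulting string, not Haystack's exact tokenization code)\n",
        "title = \"Eiffel Tower\"\n",
        "text = \"The Eiffel Tower is a wrought-iron lattice tower in Paris, France.\"\n",
        "print(f\"{title} [SEP] {text}\")"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },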
    {
      "cell_type": "markdown",
      "source": [
        "When training from scratch with the above variables, 1 epoch takes around an hour and we reached the following performance:\n",
        "\n",
        "```\n",
        "loss: 0.046580662854042276\n",
        "task_name: text_similarity\n",
        "acc: 0.992524064068483\n",
        "f1: 0.8804297774366846\n",
        "acc_and_f1: 0.9364769207525838\n",
        "average_rank: 0.19631619339984652\n",
        "report:\n",
        "                precision    recall  f1-score   support\n",
        "\n",
        "hard_negative      0.9961    0.9961    0.9961    201887\n",
        "     positive      0.8804    0.8804    0.8804      6515\n",
        "\n",
        "     accuracy                          0.9925    208402\n",
        "    macro avg      0.9383    0.9383    0.9383    208402\n",
        " weighted avg      0.9925    0.9925    0.9925    208402\n",
        "```"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Start training our model and save it when it is finished\n",
        "\n",
        "retriever.train(\n",
        "    data_dir=doc_dir,\n",
        "    train_filename=train_filename,\n",
        "    dev_filename=dev_filename,\n",
        "    test_filename=dev_filename,\n",
        "    n_epochs=1,\n",
        "    batch_size=16,\n",
        "    grad_acc_steps=8,\n",
        "    save_dir=save_dir,\n",
        "    evaluate_every=3000,\n",
        "    embed_title=True,\n",
        "    num_positives=1,\n",
        "    num_hard_negatives=1\n",
        ")"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Loading\n",
        "\n",
        "Loading our newly trained model is simple!"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "reloaded_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=None)"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },
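    {
      "cell_type": "markdown",
      "source": [
        "As a quick sanity check, the sketch below embeds a couple of queries with the reloaded retriever and prints the embedding dimensions\n",
        "(assuming the `embed_queries(texts=...)` method of `DensePassageRetriever` in this Haystack version).\n",
        "To actually retrieve documents, you would load the model with a populated document store instead of `None`."
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        }
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Quick sanity check with the reloaded retriever (a sketch; assumes the\n",
        "# embed_queries(texts=...) method of DensePassageRetriever in this Haystack version)\n",
        "queries = [\"Where is the Eiffel Tower located?\", \"Who wrote Hamlet?\"]\n",
        "embeddings = reloaded_retriever.embed_queries(texts=queries)\n",
        "print(len(embeddings), len(embeddings[0]))  # number of queries, embedding dimension"
      ],
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%%\n"
        }
      }
    },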
    {
      "cell_type": "markdown",
      "source": [
        "## About us\n",
        "\n",
        "This [Haystack](https://github.com/deepset-ai/haystack/) notebook was made with love by [deepset](https://deepset.ai/) in Berlin, Germany.\n",
        "\n",
        "We bring NLP to the industry via open source!  \n",
        "Our focus: Industry-specific language models & large-scale QA systems.\n",
        "\n",
        "Some of our other work:\n",
        "- [German BERT](https://deepset.ai/german-bert)\n",
        "- [GermanQuAD and GermanDPR](https://deepset.ai/germanquad)\n",
        "- [FARM](https://github.com/deepset-ai/FARM)\n",
        "\n",
        "Get in touch:\n",
        "[Twitter](https://twitter.com/deepset_ai) | [LinkedIn](https://www.linkedin.com/company/deepset-ai/) | [Slack](https://haystack.deepset.ai/community/join) | [GitHub Discussions](https://github.com/deepset-ai/haystack/discussions) | [Website](https://deepset.ai)\n",
        "\n",
        "By the way: [we're hiring!](https://apply.workable.com/deepset/)"
      ],
      "metadata": {
        "collapsed": false
      }
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3"
    },
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "source": [],
        "metadata": {
          "collapsed": false
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}