			* Fix top_k param * Add latest docstring and tutorial changes Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
		
			
				
	
	
		
			376 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
| {
 | |
|  "cells": [
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "# Generative QA with \"Retrieval-Augmented Generation\"\n",
 | |
|     "\n",
 | |
|     "[Open in Colab](https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial7_RAG_Generator.ipynb)\n",
 | |
|     "\n",
 | |
|     "While extractive QA highlights the span of text that answers a query,\n",
 | |
|     "generative QA can return a novel text answer that it has composed.\n",
 | |
|     "In this tutorial, you will learn how to set up a generative system using the\n",
 | |
|     "[RAG model](https://arxiv.org/abs/2005.11401), which conditions the\n",
 | |
|     "answer generator on a set of retrieved documents."
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "### Prepare environment\n",
 | |
|     "\n",
 | |
|     "#### Colab: Enable the GPU runtime\n",
 | |
|     "Make sure you enable the GPU runtime to experience decent speed in this tutorial.\n",
 | |
|     "**Runtime -> Change Runtime type -> Hardware accelerator -> GPU**\n",
 | |
|     "\n",
 | |
|     "<img src=\"https://raw.githubusercontent.com/deepset-ai/haystack/master/docs/_src/img/colab_gpu_runtime.jpg\">"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "# Make sure you have a GPU running\n",
 | |
|     "!nvidia-smi"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "Here are the packages and imports that we'll need:"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "!pip install git+https://github.com/deepset-ai/haystack.git\n",
 | |
|     "!pip install urllib3==1.25.4"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "from typing import List\n",
 | |
|     "import requests\n",
 | |
|     "import pandas as pd\n",
 | |
|     "from haystack import Document\n",
 | |
|     "from haystack.document_store.faiss import FAISSDocumentStore\n",
 | |
|     "from haystack.generator.transformers import RAGenerator\n",
 | |
|     "from haystack.retriever.dense import DensePassageRetriever"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "Let's download a csv containing some sample text and preprocess the data.\n"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "# Download sample\n",
 | |
|     "temp = requests.get(\"https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv\")\n",
 | |
|     "temp.raise_for_status()\n",
 | |
|     "with open('small_generator_dataset.csv', 'wb') as f:\n",
 | |
|     "    f.write(temp.content)\n",
 | |
|     "\n",
 | |
|     "# Create dataframe with columns \"title\" and \"text\"\n",
 | |
|     "df = pd.read_csv(\"small_generator_dataset.csv\", sep=',')\n",
 | |
|     "# Minimal cleaning\n",
 | |
|     "df.fillna(value=\"\", inplace=True)\n",
 | |
|     "\n",
 | |
|     "print(df.head())"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "We can cast our data into Haystack Document objects.\n",
 | |
|     "Alternatively, we can use plain dictionaries with \"text\" and \"meta\" fields."
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "# Use data to initialize Document objects\n",
 | |
|     "titles = list(df[\"title\"].values)\n",
 | |
|     "texts = list(df[\"text\"].values)\n",
 | |
|     "documents: List[Document] = []\n",
 | |
|     "for title, text in zip(titles, texts):\n",
 | |
|     "    documents.append(\n",
 | |
|     "        Document(\n",
 | |
|     "            text=text,\n",
 | |
|     "            meta={\n",
 | |
|     "                \"name\": title or \"\"\n",
 | |
|     "            }\n",
 | |
|     "        )\n",
 | |
|     "    )"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "Here we initialize the FAISSDocumentStore, DensePassageRetriever and RAGenerator.\n",
 | |
|     "FAISS is chosen here because it is optimized for vector storage."
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "# Initialize FAISS document store.\n",
 | |
|     "# Set `return_embedding` to `True` so the generator doesn't have to re-embed the documents\n",
 | |
|     "document_store = FAISSDocumentStore(\n",
 | |
|     "    faiss_index_factory_str=\"Flat\",\n",
 | |
|     "    return_embedding=True\n",
 | |
|     ")\n",
 | |
|     "\n",
 | |
|     "# Initialize the DPR retriever, which encodes documents and questions and queries the document store\n",
 | |
|     "retriever = DensePassageRetriever(\n",
 | |
|     "    document_store=document_store,\n",
 | |
|     "    query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
 | |
|     "    passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
 | |
|     "    use_gpu=True,\n",
 | |
|     "    embed_title=True,\n",
 | |
|     ")\n",
 | |
|     "\n",
 | |
|     "# Initialize RAG Generator\n",
 | |
|     "generator = RAGenerator(\n",
 | |
|     "    model_name_or_path=\"facebook/rag-token-nq\",\n",
 | |
|     "    use_gpu=True,\n",
 | |
|     "    top_k=1,\n",
 | |
|     "    max_length=200,\n",
 | |
|     "    min_length=2,\n",
 | |
|     "    embed_title=True,\n",
 | |
|     "    num_beams=2,\n",
 | |
|     ")"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "We write documents to the DocumentStore by first deleting any remaining documents and then calling `write_documents()`.\n",
 | |
|     "The `update_embeddings()` method uses the retriever to create an embedding for each document.\n"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "# Delete existing documents in the document store\n",
 | |
|     "document_store.delete_all_documents()\n",
 | |
|     "\n",
 | |
|     "# Write documents to document store\n",
 | |
|     "document_store.write_documents(documents)\n",
 | |
|     "\n",
 | |
|     "# Add document embeddings to the index\n",
 | |
|     "document_store.update_embeddings(\n",
 | |
|     "    retriever=retriever\n",
 | |
|     ")"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "Here are our questions:"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "QUESTIONS = [\n",
 | |
|     "    \"who got the first nobel prize in physics\",\n",
 | |
|     "    \"when is the next deadpool movie being released\",\n",
 | |
|     "    \"which mode is used for short wave broadcast service\",\n",
 | |
|     "    \"who is the owner of reading football club\",\n",
 | |
|     "    \"when is the next scandal episode coming out\",\n",
 | |
|     "    \"when is the last time the philadelphia won the superbowl\",\n",
 | |
|     "    \"what is the most current adobe flash player version\",\n",
 | |
|     "    \"how many episodes are there in dragon ball z\",\n",
 | |
|     "    \"what is the first step in the evolution of the eye\",\n",
 | |
|     "    \"where is gall bladder situated in human body\",\n",
 | |
|     "    \"what is the main mineral in lithium batteries\",\n",
 | |
|     "    \"who is the president of usa right now\",\n",
 | |
|     "    \"where do the greasers live in the outsiders\",\n",
 | |
|     "    \"panda is a national animal of which country\",\n",
 | |
|     "    \"what is the name of manchester united stadium\",\n",
 | |
|     "]"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "source": [
 | |
|     "Now let's run our system!\n",
 | |
|     "The retriever will pick out a small subset of documents that it finds relevant.\n",
 | |
|     "These are used to condition the generator as it generates the answer.\n",
 | |
|     "What it should return then are novel text spans that form an answer to your question!"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "# Now generate an answer for each question\n",
 | |
|     "for question in QUESTIONS:\n",
 | |
|     "    # Retrieve related documents from retriever\n",
 | |
|     "    retriever_results = retriever.retrieve(\n",
 | |
|     "        query=question\n",
 | |
|     "    )\n",
 | |
|     "\n",
 | |
|     "    # Now generate answer from question and retrieved documents\n",
 | |
|     "    predicted_result = generator.predict(\n",
 | |
|     "        query=question,\n",
 | |
|     "        documents=retriever_results,\n",
 | |
|     "        top_k=1\n",
 | |
|     "    )\n",
 | |
|     "\n",
 | |
|     "    # Print your answer\n",
 | |
|     "    answers = predicted_result[\"answers\"]\n",
 | |
|     "    print(f'Generated answer is \\'{answers[0][\"answer\"]}\\' for the question = \\'{question}\\'')"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "# Or alternatively use the Pipeline class\n",
 | |
|     "from haystack.pipeline import GenerativeQAPipeline\n",
 | |
|     "\n",
 | |
|     "pipe = GenerativeQAPipeline(generator=generator, retriever=retriever)\n",
 | |
|     "for question in QUESTIONS:\n",
 | |
|     "    res = pipe.run(query=question, top_k_generator=1, top_k_retriever=5)\n",
 | |
|     "    print(res)"
 | |
|    ],
 | |
|    "metadata": {
 | |
|     "collapsed": false,
 | |
|     "pycharm": {
 | |
|      "name": "#%%\n"
 | |
|     }
 | |
|    }
 | |
|   }
 | |
|  ],
 | |
|  "metadata": {
 | |
|   "kernelspec": {
 | |
|    "display_name": "Python 3",
 | |
|    "language": "python",
 | |
|    "name": "python3"
 | |
|   },
 | |
|   "language_info": {
 | |
|    "codemirror_mode": {
 | |
|     "name": "ipython",
 | |
|     "version": 2
 | |
|    },
 | |
|    "file_extension": ".py",
 | |
|    "mimetype": "text/x-python",
 | |
|    "name": "python",
 | |
|    "nbconvert_exporter": "python",
 | |
|    "pygments_lexer": "ipython2",
 | |
|    "version": "2.7.6"
 | |
|   },
 | |
|   "pycharm": {
 | |
|    "stem_cell": {
 | |
|     "cell_type": "raw",
 | |
|     "source": [],
 | |
|     "metadata": {
 | |
|      "collapsed": false
 | |
|     }
 | |
|    }
 | |
|   }
 | |
|  },
 | |
|  "nbformat": 4,
 | |
|  "nbformat_minor": 0
 | |
| } |