{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 7: RAG Generator\n",
    "\n",
    "This notebook downloads a small CSV of sample documents, writes them to a FAISS document store, embeds them with a Dense Passage Retriever (DPR), and then uses a RAG generator to answer a list of sample questions from the retrieved documents.\n",
    "\n",
    "Both the retriever and the generator are configured with `use_gpu=False`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "iDyfhfyp7Sjh"
   },
   "source": [
    "# NOTE(review): Haystack is installed from the repository's default branch without a pinned commit or tag,\n",
    "# so the imports below may break if its module layout changes - consider pinning.\n",
    "!pip install git+https://github.com/deepset-ai/haystack.git\n",
    "!pip install urllib3==1.25.4\n",
    "# torchvision 0.7.0 is the release built against torch 1.6.0 (0.6.1 targets torch 1.5.1)\n",
    "!pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ICZanGLa7khF"
   },
   "source": [
    "from typing import List\n",
    "import requests\n",
    "import pandas as pd\n",
    "from haystack import Document\n",
    "from haystack.document_store.faiss import FAISSDocumentStore\n",
    "from haystack.generator.transformers import RAGenerator\n",
    "from haystack.retriever.dense import DensePassageRetriever"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "D3f-CQ4c7lEN"
   },
   "source": [
    "# Add documents from which you want to generate answers.\n",
    "# Download a CSV containing some sample documents.\n",
    "DATASET_URL = \"https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv\"\n",
    "DATASET_PATH = \"small_generator_dataset.csv\"\n",
    "\n",
    "# A timeout keeps the cell from hanging forever on a stalled connection\n",
    "response = requests.get(DATASET_URL, timeout=60)\n",
    "# Fail loudly on an HTTP error instead of saving an error page as the CSV\n",
    "response.raise_for_status()\n",
    "# The context manager closes (and flushes) the file before pandas reads it back\n",
    "with open(DATASET_PATH, \"wb\") as csv_file:\n",
    "    csv_file.write(response.content)\n",
    "\n",
    "# Get dataframe with columns \"title\" and \"text\"\n",
    "# Minimal cleaning: replace missing values with empty strings\n",
    "df = pd.read_csv(DATASET_PATH, sep=\",\").fillna(value=\"\")\n",
    "\n",
    "print(df.head())\n",
    "\n",
    "# Convert each row to the Haystack document format, keeping the title as metadata\n",
    "documents: List[Document] = [\n",
    "    Document(\n",
    "        text=text,\n",
    "        meta={\n",
    "            \"name\": title or \"\"\n",
    "        }\n",
    "    )\n",
    "    for title, text in zip(df[\"title\"], df[\"text\"])\n",
    "]"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "upRu3ebX7nr_"
   },
   "source": [
    "# Initialize the FAISS document store for the documents and the corresponding index for embeddings\n",
    "# Set `return_embedding` to `True`, so the generator doesn't have to perform re-embedding\n",
    "document_store = FAISSDocumentStore(\n",
    "    faiss_index_factory_str=\"Flat\",\n",
    "    return_embedding=True\n",
    ")\n",
    "\n",
    "# Initialize the DPR retriever to encode documents, encode questions and query documents\n",
    "retriever = DensePassageRetriever(\n",
    "    document_store=document_store,\n",
    "    query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
    "    passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
    "    use_gpu=False,\n",
    "    embed_title=True,\n",
    ")\n",
    "\n",
    "# Initialize the RAG generator\n",
    "generator = RAGenerator(\n",
    "    model_name_or_path=\"facebook/rag-token-nq\",\n",
    "    use_gpu=False,\n",
    "    top_k_answers=1,\n",
    "    max_length=200,\n",
    "    min_length=2,\n",
    "    embed_title=True,\n",
    "    num_beams=2,\n",
    ")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "as8j7hkW7rOW"
   },
   "source": [
    "# Delete existing documents in the document store, so re-running this cell starts from a clean store\n",
    "document_store.delete_all_documents()\n",
    "# Write documents to the document store\n",
    "document_store.write_documents(documents)\n",
    "# Add document embeddings to the index\n",
    "document_store.update_embeddings(\n",
    "    retriever=retriever\n",
    ")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "j8It45R872vb",
    "cellView": "form"
   },
   "source": [
    "#@title\n",
    "# Now ask your questions\n",
    "# We have some sample questions\n",
    "QUESTIONS = [\n",
    "    \"who got the first nobel prize in physics\",\n",
    "    \"when is the next deadpool movie being released\",\n",
    "    \"which mode is used for short wave broadcast service\",\n",
    "    \"who is the owner of reading football club\",\n",
    "    \"when is the next scandal episode coming out\",\n",
    "    \"when is the last time the philadelphia won the superbowl\",\n",
    "    \"what is the most current adobe flash player version\",\n",
    "    \"how many episodes are there in dragon ball z\",\n",
    "    \"what is the first step in the evolution of the eye\",\n",
    "    \"where is gall bladder situated in human body\",\n",
    "    \"what is the main mineral in lithium batteries\",\n",
    "    \"who is the president of usa right now\",\n",
    "    \"where do the greasers live in the outsiders\",\n",
    "    \"panda is a national animal of which country\",\n",
    "    \"what is the name of manchester united stadium\",\n",
    "]"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "xPUHRuTP742h"
   },
   "source": [
    "# Now generate an answer for each question\n",
    "for question in QUESTIONS:\n",
    "    # Retrieve related documents from the retriever\n",
    "    retriever_results = retriever.retrieve(\n",
    "        query=question\n",
    "    )\n",
    "\n",
    "    # Now generate an answer from the question and the retrieved documents\n",
    "    predicted_result = generator.predict(\n",
    "        question=question,\n",
    "        documents=retriever_results,\n",
    "        top_k=1\n",
    "    )\n",
    "\n",
    "    # Print your answer (only the first answer is shown, matching top_k=1)\n",
    "    answers = predicted_result[\"answers\"]\n",
    "    print(f'Generated answer is \\'{answers[0][\"answer\"]}\\' for the question = \\'{question}\\'')"
   ],
   "execution_count": null,
   "outputs": []
  }
 ],
 "metadata": {
  "colab": {
   "name": "Tutorial7_RAG_Generator.ipynb",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}