haystack/tutorials/Tutorial7_RAG_Generator.py
Malte Pietsch 4a6c9302b3
Redesign primitives - Document, Answer, Label (#1398)
* first draft / notes on new primitives

* wip label / feedback refactor

* rename doc.text -> doc.content. add doc.content_type

* add datatype for content

* remove faq_question_field from ES and weaviate. rename text_field -> content_field in docstores. update tutorials for content field

* update converters for . Add warning for empty

* rename label.question -> label.query. Allow sorting of Answers.

* WIP primitives

* update ui/reader for new Answer format

* Improve Label. First refactoring of MultiLabel. Adjust eval code

* fixed workflow conflict with introducing new one (#1472)

* Add latest docstring and tutorial changes

* make add_eval_data() work again

* fix reader formats. WIP fix _extract_docs_and_labels_from_dict

* fix test reader

* Add latest docstring and tutorial changes

* fix another test case for reader

* fix mypy in farm reader.eval()

* fix mypy in farm reader.eval()

* WIP ORM refactor

* Add latest docstring and tutorial changes

* fix mypy weaviate

* make label and multilabel dataclasses

* bump mypy env in CI to python 3.8

* WIP refactor Label ORM

* WIP refactor Label ORM

* simplify tests for individual doc stores

* WIP refactoring markers of tests

* test alternative approach for tests with existing parametrization

* WIP refactor ORMs

* fix skip logic of already parametrized tests

* fix weaviate behaviour in tests - not parametrizing it in our general test cases.

* Add latest docstring and tutorial changes

* fix some tests

* remove sql from document_store_types

* fix markers for generator and pipeline test

* remove inmemory marker

* remove unneeded elasticsearch markers

* add dataclasses-json dependency. adjust ORM to just store JSON repr

* ignore type as dataclasses_json seems to miss functionality here

* update readme and contributing.md

* update contributing

* adjust example

* fix duplicate doc handling for custom index

* Add latest docstring and tutorial changes

* fix some ORM issues. fix get_all_labels_aggregated.

* update drop flags where get_all_labels_aggregated() was used before

* Add latest docstring and tutorial changes

* add to_json(). add + fix tests

* fix no_answer handling in label / multilabel

* fix duplicate docs in memory doc store. change primary key for sql doc table

* fix mypy issues

* fix mypy issues

* haystack/retriever/base.py

* fix test_write_document_meta[elastic]

* fix test_elasticsearch_custom_fields

* fix test_labels[elastic]

* fix crawler

* fix converter

* fix docx converter

* fix preprocessor

* fix test_utils

* fix tfidf retriever. fix selection of docstore in tests with multiple fixtures / parameterizations

* Add latest docstring and tutorial changes

* fix crawler test. fix ocrconverter attribute

* fix test_elasticsearch_custom_query

* fix generator pipeline

* fix ocr converter

* fix ragenerator

* Add latest docstring and tutorial changes

* fix test_load_and_save_yaml for elasticsearch

* fixes for pipeline tests

* fix faq pipeline

* fix pipeline tests

* Add latest docstring and tutorial changes

* fix weaviate

* Add latest docstring and tutorial changes

* trigger CI

* satisfy mypy

* Add latest docstring and tutorial changes

* satisfy mypy

* Add latest docstring and tutorial changes

* trigger CI

* fix question generation test

* fix ray. fix Q-generation

* fix translator test

* satisfy mypy

* wip refactor feedback rest api

* fix rest api feedback endpoint

* fix doc classifier

* remove relation of Labels -> Docs in SQL ORM

* fix faiss/milvus tests

* fix doc classifier test

* fix eval test

* fixing eval issues

* Add latest docstring and tutorial changes

* fix mypy

* WIP replace dataclasses-json with manual serialization

* Add latest docstring and tutorial changes

* revert to dataclass-json serialization for now. remove debug prints.

* update docstrings

* fix extractor. fix Answer Span init

* fix api test

* keep meta data of answers in reader.run()

* fix meta handling

* address review feedback

* Add latest docstring and tutorial changes

* make document=None for open domain labels

* add import

* fix print utils

* fix rest api

* address review feedback

* Add latest docstring and tutorial changes

* fix mypy

Co-authored-by: Markus Paff <markuspaff.mp@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2021-10-13 14:23:23 +02:00

from typing import List
import requests
import pandas as pd
from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.generator.transformers import RAGenerator
from haystack.retriever.dense import DensePassageRetriever
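

# A minimal sketch, not part of the original tutorial: the row-to-document
# conversion used in the tutorial function, written as a standalone helper
# that returns plain dicts instead of Document objects.
# Assumptions: `rows_to_document_dicts` is a hypothetical name, and the
# {"content": ..., "meta": {...}} layout simply mirrors the Document fields
# used in this file. Whether a given document store accepts such dicts
# should be checked against the installed Haystack version.
def rows_to_document_dicts(titles, texts):
    """Pair each title with its text and return one dict per document."""
    docs = []
    for title, text in zip(titles, texts):
        # Fall back to an empty name when the title is missing
        docs.append({"content": text, "meta": {"name": title or ""}})
    return docs

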
def tutorial7_rag_generator():
    # Add the documents from which you want to generate answers.
    # Download a CSV file containing some sample documents.
    temp = requests.get("https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv")
    # Fail early if the download did not succeed
    temp.raise_for_status()
    with open("small_generator_dataset.csv", "wb") as f:
        f.write(temp.content)
    # Load a dataframe with the columns "title" and "text"
    df = pd.read_csv("small_generator_dataset.csv", sep=",")
    # Minimal cleaning: replace missing values with empty strings
    df.fillna(value="", inplace=True)
    print(df.head())
    titles = list(df["title"].values)
    texts = list(df["text"].values)
    # Convert the rows to the Haystack Document format
    documents: List[Document] = []
    for title, text in zip(titles, texts):
        documents.append(
            Document(
                content=text,
                meta={
                    "name": title or ""
                }
            )
        )
    # Initialize the FAISS document store, which holds the documents and the index for their embeddings.
    # Set `return_embedding` to `True` so the generator doesn't have to re-embed the documents.
    document_store = FAISSDocumentStore(
        faiss_index_factory_str="Flat",
        return_embedding=True
    )
    # Initialize the DPR retriever to encode documents, encode questions, and query the document store
    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        use_gpu=True,
        embed_title=True,
    )
    # Initialize RAG Generator
    generator = RAGenerator(
        model_name_or_path="facebook/rag-token-nq",
        use_gpu=True,
        top_k=1,
        max_length=200,
        min_length=2,
        embed_title=True,
        num_beams=2,
    )
    # Delete any existing documents in the document store
    document_store.delete_documents()
    # Write the documents to the document store
    document_store.write_documents(documents)
    # Compute the document embeddings and add them to the index
    document_store.update_embeddings(
        retriever=retriever
    )
    # Now ask your questions. Here are some sample questions.
    QUESTIONS = [
        "who got the first nobel prize in physics",
        "when is the next deadpool movie being released",
        "which mode is used for short wave broadcast service",
        "who is the owner of reading football club",
        "when is the next scandal episode coming out",
        "when is the last time the philadelphia won the superbowl",
        "what is the most current adobe flash player version",
        "how many episodes are there in dragon ball z",
        "what is the first step in the evolution of the eye",
        "where is gall bladder situated in human body",
        "what is the main mineral in lithium batteries",
        "who is the president of usa right now",
        "where do the greasers live in the outsiders",
        "panda is a national animal of which country",
        "what is the name of manchester united stadium",
    ]
    # Now generate an answer for each question
    for question in QUESTIONS:
        # Retrieve related documents with the retriever
        retriever_results = retriever.retrieve(
            query=question
        )
        # Generate an answer from the question and the retrieved documents
        predicted_result = generator.predict(
            query=question,
            documents=retriever_results,
            top_k=1
        )
        # Print your answer
        answers = predicted_result["answers"]
        print(f'Generated answer is \'{answers[0]["answer"]}\' for the question = \'{question}\'')
    # Or alternatively use the Pipeline class
    from haystack.pipeline import GenerativeQAPipeline
    pipe = GenerativeQAPipeline(generator=generator, retriever=retriever)
    for question in QUESTIONS:
        res = pipe.run(query=question, params={"Generator": {"top_k": 1}, "Retriever": {"top_k": 5}})
        print(res)
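

# A minimal sketch, not part of the original tutorial: the question loop in
# tutorial7_rag_generator() indexes answers[0] directly, which raises an
# IndexError if the generator returns no answers.
# Assumptions: `first_answer_text` is a hypothetical helper, and the result
# is taken to be a dict with an "answers" list of dicts that carry an
# "answer" key, as in the print statement of the tutorial function.
def first_answer_text(predicted_result, default=""):
    """Return the text of the first answer, or `default` if there is none."""
    answers = predicted_result.get("answers") or []
    if not answers:
        return default
    return answers[0].get("answer", default)

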
if __name__ == "__main__":
    tutorial7_rag_generator()
# This Haystack script was made with love by deepset in Berlin, Germany
# Haystack: https://github.com/deepset-ai/haystack
# deepset: https://deepset.ai/