DensePassageRetriever: Add Training, Refactor Inference to FARM modules (#527)

* dpr training and inference code refactored with FARM modules

* dpr test cases modified

* docstring and default arguments updated

* dpr training docstring updated

* bugfix in dense retriever inference, DPR tutorials modified

* Bump FARM to 0.5.0

* update README for DPR


* mypy errors fix

* DPR instantiation bugfix

* Fix DPR init in RAG Tutorial

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
kolk 2020-10-30 23:52:06 +05:30 committed by GitHub
parent f13443054a
commit 72b637ae6d
9 changed files with 411 additions and 374 deletions


@@ -159,8 +159,11 @@ Example
.. code-block:: python
retriever = DensePassageRetriever(document_store=document_store,
embedding_model="dpr-bert-base-nq",
do_lower_case=True, use_gpu=True)
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=True,
batch_size=16,
embed_title=True)
retriever.retrieve(query="Why did the revenue increase?")
# returns: [Document, Document]
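For context, a typical end-to-end flow combining the retriever with a document store looks like the sketch below. It mirrors the tutorial changes in this commit; the sample document and query strings are illustrative, and a local Elasticsearch instance is assumed.

.. code-block:: python

    from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
    from haystack.retriever.dense import DensePassageRetriever

    document_store = ElasticsearchDocumentStore()  # assumes Elasticsearch on localhost:9200
    retriever = DensePassageRetriever(document_store=document_store,
                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
    # write documents, then (re)compute their embeddings with the new retriever
    document_store.write_documents([{"text": "Revenue increased due to higher sales.",
                                     "meta": {"name": "Q3 report"}}])
    document_store.update_embeddings(retriever=retriever)
    print(retriever.retrieve(query="Why did the revenue increase?"))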


@@ -2,16 +2,29 @@ import logging
from typing import List, Union, Tuple, Optional
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from farm.infer import Inferencer
from haystack.document_store.base import BaseDocumentStore
from haystack import Document
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.base import BaseRetriever
from haystack.retriever.sparse import logger
from transformers.modeling_dpr import DPRContextEncoder, DPRQuestionEncoder
from transformers.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
from farm.infer import Inferencer
from farm.modeling.tokenization import Tokenizer
from farm.modeling.language_model import LanguageModel
from farm.modeling.biadaptive_model import BiAdaptiveModel
from farm.modeling.prediction_head import TextSimilarityHead
from farm.data_handler.processor import TextSimilarityProcessor
from farm.data_handler.data_silo import DataSilo
from farm.data_handler.dataloader import NamedDataLoader
from farm.modeling.optimization import initialize_optimizer
from farm.train import Trainer
from torch.utils.data.sampler import SequentialSampler
logger = logging.getLogger(__name__)
@@ -28,11 +41,13 @@ class DensePassageRetriever(BaseRetriever):
document_store: BaseDocumentStore,
query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base",
max_seq_len: int = 256,
max_seq_len_query: int = 64,
max_seq_len_passage: int = 256,
use_gpu: bool = True,
batch_size: int = 16,
embed_title: bool = True,
remove_sep_tok_from_untitled_passages: bool = True
use_fast_tokenizers: bool = True,
similarity_function: str = "dot_product"
):
"""
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -45,7 +60,8 @@ class DensePassageRetriever(BaseRetriever):
:param passage_embedding_model: Local path or remote name of passage encoder checkpoint. The format equals the
one used by hugging-face transformers' modelhub models
Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
:param max_seq_len: Longest length of each sequence
:param max_seq_len_query: Maximum number of tokens for the query text. Longer queries will be cut down.
:param max_seq_len_passage: Maximum number of tokens for the passage/context text. Longer passages will be cut down.
:param use_gpu: Whether to use gpu or not
:param batch_size: Number of questions or passages to encode at once
:param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
@@ -54,14 +70,12 @@ class DensePassageRetriever(BaseRetriever):
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}.
:param remove_sep_tok_from_untitled_passages: If embed_title is ``True``, there are different strategies to deal with documents that don't have a title.
If this param is ``True`` => Embed passage as single text, similar to embed_title = False (i.e [CLS] passage_tok1 ... [SEP]).
If this param is ``False`` => Embed passage as text pair with empty title (i.e. [CLS] [SEP] passage_tok1 ... [SEP])
"""
self.document_store = document_store
self.batch_size = batch_size
self.max_seq_len = max_seq_len
self.max_seq_len_passage = max_seq_len_passage
self.max_seq_len_query = max_seq_len_query
if use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")
@@ -69,14 +83,39 @@ class DensePassageRetriever(BaseRetriever):
self.device = torch.device("cpu")
self.embed_title = embed_title
self.remove_sep_tok_from_untitled_passages = remove_sep_tok_from_untitled_passages
# Init & Load Encoders
self.query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(query_embedding_model)
self.query_encoder = DPRQuestionEncoder.from_pretrained(query_embedding_model).to(self.device)
self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
do_lower_case=True, use_fast=use_fast_tokenizers)
self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
language_model_class="DPRQuestionEncoder")
self.passage_tokenizer = DPRContextEncoderTokenizer.from_pretrained(passage_embedding_model)
self.passage_encoder = DPRContextEncoder.from_pretrained(passage_embedding_model).to(self.device)
self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
do_lower_case=True, use_fast=use_fast_tokenizers)
self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
language_model_class="DPRContextEncoder")
self.processor = TextSimilarityProcessor(tokenizer=self.query_tokenizer,
passage_tokenizer=self.passage_tokenizer,
max_seq_len_passage=self.max_seq_len_passage,
max_seq_len_query=self.max_seq_len_query,
label_list=["hard_negative", "positive"],
metric="text_similarity_metric",
embed_title=self.embed_title,
num_hard_negatives=0,
num_negatives=0)
prediction_head = TextSimilarityHead(similarity_function=similarity_function)
self.model = BiAdaptiveModel(
language_model1=self.query_encoder,
language_model2=self.passage_encoder,
prediction_heads=[prediction_head],
embeds_dropout_prob=0.1,
lm1_output_types=["per_sequence"],
lm2_output_types=["per_sequence"],
device=self.device,
)
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
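The TextSimilarityHead scores query embeddings against passage embeddings using the configured similarity_function. As a rough intuition for the default "dot_product" setting, here is a minimal numpy sketch; the function name is illustrative, not FARM's API.

.. code-block:: python

    import numpy as np

    def dot_product_scores(query_vectors: np.ndarray, passage_vectors: np.ndarray) -> np.ndarray:
        # score every query against every passage -> shape (n_queries, n_passages)
        return query_vectors @ passage_vectors.T

    queries = np.random.rand(2, 768)   # two 768-dim query embeddings (DPR base)
    passages = np.random.rand(5, 768)  # five passage embeddings
    print(dot_product_scores(queries, passages).shape)  # (2, 5)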
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
if index is None:
@@ -85,6 +124,48 @@ class DensePassageRetriever(BaseRetriever):
documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters, index=index)
return documents
def _get_predictions(self, dicts, tokenizer):
"""
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
:param dicts: list of dictionaries
examples:[{'query': "where is florida?"}, {'query': "who wrote lord of the rings?"}, ...]
[{'passages': [{
"title": 'Big Little Lies (TV series)',
"text": 'series garnered several accolades. It received..',
"label": 'positive',
"external_id": '18768923'},
{"title": 'Framlingham Castle',
"text": 'Castle on the Hill "Castle on the Hill" is a song by English..',
"label": 'positive',
"external_id": '19930582'}, ...]
:return: dictionary of embeddings for "passages" and "query"
"""
dataset, tensor_names, baskets = self.processor.dataset_from_dicts(
dicts, indices=[i for i in range(len(dicts))], return_baskets=True
)
data_loader = NamedDataLoader(
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
)
all_embeddings = {"query": torch.tensor([]).to(self.device), "passages": torch.tensor([]).to(self.device)}
self.model.eval()
for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
batch = {key: batch[key].to(self.device) for key in batch}
# get query and passage embeddings (forward pass through both encoders)
with torch.no_grad():
query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
all_embeddings["query"] = torch.cat((all_embeddings["query"], query_embeddings), dim=0) \
if isinstance(query_embeddings, torch.Tensor) else None
all_embeddings["passages"] = torch.cat((all_embeddings["passages"], passage_embeddings), dim=0) \
if isinstance(passage_embeddings, torch.Tensor) else None
# convert embeddings to numpy array
for k, v in all_embeddings.items():
all_embeddings[k] = v.cpu().numpy() if v is not None else None
return all_embeddings
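A hypothetical call to this internal helper, using the input formats from the docstring above; `retriever` is an initialized DensePassageRetriever, and the titles and texts are illustrative.

.. code-block:: python

    query_dicts = [{"query": "where is florida?"}]
    passage_dicts = [{"passages": [{"title": "Florida",
                                    "text": "Florida is a state in the southeastern United States.",
                                    "label": "positive",
                                    "external_id": "1"}]}]
    query_emb = retriever._get_predictions(query_dicts, retriever.query_tokenizer)["query"]
    passage_emb = retriever._get_predictions(passage_dicts, retriever.passage_tokenizer)["passages"]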
def embed_queries(self, texts: List[str]) -> List[np.array]:
"""
Create embeddings for a list of queries using the query encoder
@@ -92,10 +173,8 @@ class DensePassageRetriever(BaseRetriever):
:param texts: Queries to embed
:return: Embeddings, one per input query
"""
queries = [self._normalize_query(q) for q in texts]
result = self._generate_batch_predictions(texts=queries, model=self.query_encoder,
tokenizer=self.query_tokenizer,
batch_size=self.batch_size)
queries = [{'query': q} for q in texts]
result = self._get_predictions(queries, self.query_tokenizer)["query"]
return result
def embed_passages(self, docs: List[Document]) -> List[np.array]:
@@ -105,158 +184,114 @@ class DensePassageRetriever(BaseRetriever):
:param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack.
:return: Embeddings of documents / passages, shape (batch_size, embedding_dim)
"""
texts = [d.text for d in docs]
titles = None
if self.embed_title:
titles = [d.meta["name"] if d.meta and "name" in d.meta else "" for d in docs]
passages = [{'passages': [{
"title": d.meta["name"] if d.meta and "name" in d.meta else "",
"text": d.text,
"label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
"external_id": d.id}]
} for d in docs]
embeddings = self._get_predictions(passages, self.passage_tokenizer)["passages"]
result = self._generate_batch_predictions(texts=texts, titles=titles,
model=self.passage_encoder,
tokenizer=self.passage_tokenizer,
batch_size=self.batch_size)
return result
return embeddings
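In practice this means passages should carry their title in meta["name"] when embed_title=True, as described in the __init__ docstring. A short usage sketch with illustrative document texts:

.. code-block:: python

    from haystack import Document

    docs = [Document(text="Aaron is a prophet, high priest, and the brother of Moses.",
                     meta={"name": "Aaron"}),
            Document(text="Angola's capital, Luanda, lies on the Atlantic coast.",
                     meta={"name": "Angola"})]
    embeddings = retriever.embed_passages(docs)  # one 768-dim vector per passage for DPR base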
def _normalize_query(self, query: str) -> str:
if query[-1] == '?':
query = query[:-1]
return query
def _tensorizer(self, tokenizer: Union[DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer],
text: List[str],
title: Optional[List[str]] = None,
add_special_tokens: bool = True):
def train(self,
data_dir: str,
train_filename: str,
dev_filename: str = None,
test_filename: str = None,
batch_size: int = 2,
embed_title: bool = True,
num_hard_negatives: int = 1,
num_negatives: int = 0,
n_epochs: int = 3,
evaluate_every: int = 1000,
n_gpu: int = 1,
learning_rate: float = 1e-5,
epsilon: float = 1e-08,
weight_decay: float = 0.0,
num_warmup_steps: int = 100,
grad_acc_steps: int = 1,
optimizer_name: str = "TransformersAdamW",
optimizer_correct_bias: bool = True,
save_dir: str = "../saved_models/dpr-tutorial",
query_encoder_save_dir: str = "lm1",
passage_encoder_save_dir: str = "lm2"
):
"""
Creates tensors from text sequences
:Example:
>>> ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained()
>>> dpr_object._tensorizer(tokenizer=ctx_tokenizer, text=passages, title=titles)
:param tokenizer: An instance of DPRQuestionEncoderTokenizer or DPRContextEncoderTokenizer.
:param text: list of text sequences to be tokenized
:param title: optional list of titles associated with each text sequence
:param add_special_tokens: boolean for whether to encode special tokens in each sequence
Returns:
token_ids: list of token ids from vocabulary
token_type_ids: list of token type ids
attention_mask: list of indices specifying which tokens should be attended to by the encoder
Train a DensePassageRetriever model
:param data_dir: Directory where training file, dev file and test file are present
:param train_filename: training filename
:param dev_filename: development set filename, file to be used by model in eval step of training
:param test_filename: test set filename, file to be used by model in test step after training
:param batch_size: total number of samples in 1 batch of data
:param embed_title: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
:param num_hard_negatives: number of hard negative passages (passages which are very similar to the query, e.g. with a high BM25 score, but do not contain the answer)
:param num_negatives: number of negative passages (random passages from the dataset which do not contain the answer to the query)
:param n_epochs: number of epochs to train the model on
:param evaluate_every: number of training steps after which evaluation is run
:param n_gpu: number of gpus to train on
:param learning_rate: learning rate of optimizer
:param epsilon: epsilon parameter of optimizer
:param weight_decay: weight decay parameter of optimizer
:param grad_acc_steps: number of steps to accumulate gradient over before back-propagation is done
:param optimizer_name: what optimizer to use (default: TransformersAdamW)
:param num_warmup_steps: number of warmup steps
:param optimizer_correct_bias: Whether to correct bias in optimizer
:param save_dir: directory where models are saved
:param query_encoder_save_dir: directory inside save_dir where query_encoder model files are saved
:param passage_encoder_save_dir: directory inside save_dir where passage_encoder model files are saved
"""
# combine titles with passages only if some titles are present with passages
if self.embed_title and title:
final_text = [tuple((title_, text_)) for title_, text_ in zip(title, text)] #type: Union[List[Tuple[str, ...]], List[str]]
else:
final_text = text
out = tokenizer.batch_encode_plus(final_text, add_special_tokens=add_special_tokens, truncation=True,
max_length=self.max_seq_len,
pad_to_max_length=True)
self.embed_title = embed_title
self.processor = TextSimilarityProcessor(tokenizer=self.query_tokenizer,
passage_tokenizer=self.passage_tokenizer,
max_seq_len_passage=self.max_seq_len_passage,
max_seq_len_query=self.max_seq_len_query,
label_list=["hard_negative", "positive"],
metric="text_similarity_metric",
data_dir=data_dir,
train_filename=train_filename,
dev_filename=dev_filename,
test_filename=test_filename,
embed_title=self.embed_title,
num_hard_negatives=num_hard_negatives,
num_negatives=num_negatives)
token_ids = torch.tensor(out['input_ids']).to(self.device)
token_type_ids = torch.tensor(out['token_type_ids']).to(self.device)
attention_mask = torch.tensor(out['attention_mask']).to(self.device)
return token_ids, token_type_ids, attention_mask
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
def _remove_sep_tok_from_untitled_passages(self, titles, ctx_ids_batch, ctx_attn_mask):
"""
Removes the [SEP] token from untitled samples in a batch. For batches that contain some untitled passages, the [SEP]
token used to separate title and passage is removed from those samples
(the official DPR code does not encode [SEP] tokens in untitled passages)
data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False)
:Example:
# Encoding passages with 'embed_title' = True. 1st passage is titled, 2nd passage is untitled
>>> texts = ['Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions.',
'Democratic Republic of the Congo to the south. Angola\'s capital, Luanda, lies on the Atlantic coast in the northwest of the country.'
]
>> titles = ["0", '']
>>> token_ids, token_type_ids, attention_mask = self._tensorizer(self.passage_tokenizer, text=texts, title=titles)
>>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[0]]
['[CLS]', '0', '[SEP]', 'aaron', 'aaron', '(', 'or', ';', ....]
>>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[1]]
['[CLS]', '[SEP]', 'democratic', 'republic', 'of', 'the', ....]
>>> new_ids, new_attn = self._remove_sep_tok_from_untitled_passages(titles, token_ids, attention_mask)
>>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[0]]
['[CLS]', '0', '[SEP]', 'aaron', 'aaron', '(', 'or', ';', ....]
>>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[1]]
['[CLS]', 'democratic', 'republic', 'of', 'the', 'congo', ...]
# 5. Create an optimizer
self.model, optimizer, lr_schedule = initialize_optimizer(
model=self.model,
learning_rate=learning_rate,
optimizer_opts={"name": optimizer_name, "correct_bias": optimizer_correct_bias,
"weight_decay": weight_decay, "eps": epsilon},
schedule_opts={"name": "LinearWarmup", "num_warmup_steps": num_warmup_steps},
n_batches=len(data_silo.loaders["train"]),
n_epochs=n_epochs,
grad_acc_steps=grad_acc_steps,
device=self.device
)
:param titles: list of titles for each sample
:param ctx_ids_batch: tensor of shape (batch_size, max_seq_len) containing token indices
:param ctx_attn_mask: tensor of shape (batch_size, max_seq_len) containing attention mask
# 6. Feed everything to the Trainer, which takes care of training the model and evaluating it from time to time
trainer = Trainer(
model=self.model,
optimizer=optimizer,
data_silo=data_silo,
epochs=n_epochs,
n_gpu=n_gpu,
lr_schedule=lr_schedule,
evaluate_every=evaluate_every,
device=self.device,
)
Returns:
ctx_ids_batch: tensor of shape (batch_size, max_seq_len) containing token indices with [SEP] token removed
ctx_attn_mask: tensor of shape (batch_size, max_seq_len) reflecting the ctx_ids_batch changes
"""
# Skip [SEP] removal if passage encoder not bert model
if self.passage_encoder.ctx_encoder.base_model_prefix != 'bert_model':
logger.warning("Context encoder is not a BERT model. Skipping removal of [SEP] tokens")
return ctx_ids_batch, ctx_attn_mask
# 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
trainer.train()
# create a mask for titles in the batch
titles_mask = torch.tensor(list(map(lambda x: 0 if x == "" else 1, titles))).to(self.device)
self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
self.processor.save(Path(save_dir))
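Putting the pieces together, a minimal training sketch follows; the paths are placeholders, and the training files are assumed to follow the original DPR JSON format (a list of dicts with "question", "positive_ctxs" and "hard_negative_ctxs") that TextSimilarityProcessor reads.

.. code-block:: python

    retriever.train(data_dir="data/dpr_training",     # placeholder paths
                    train_filename="train.json",
                    dev_filename="dev.json",
                    n_epochs=1,
                    batch_size=4,
                    num_hard_negatives=1,
                    save_dir="saved_models/dpr")
    # With the defaults above, the two encoders end up in
    # saved_models/dpr/lm1 (query) and saved_models/dpr/lm2 (passage).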
# get all untitled passage indices
no_title_indices = torch.nonzero(1 - titles_mask).squeeze(-1)
# remove [SEP] token index for untitled passages and add 1 pad to compensate
ctx_ids_batch[no_title_indices] = torch.cat((ctx_ids_batch[no_title_indices, 0].unsqueeze(-1),
ctx_ids_batch[no_title_indices, 2:],
torch.tensor([self.passage_tokenizer.pad_token_id]).expand(len(no_title_indices)).unsqueeze(-1).to(self.device)),
dim=1)
# Modify attention mask to reflect [SEP] token removal and pad addition in ctx_ids_batch
ctx_attn_mask[no_title_indices] = torch.cat((ctx_attn_mask[no_title_indices, 0].unsqueeze(-1),
ctx_attn_mask[no_title_indices, 2:],
torch.tensor([self.passage_tokenizer.pad_token_id]).expand(len(no_title_indices)).unsqueeze(-1).to(self.device)),
dim=1)
return ctx_ids_batch, ctx_attn_mask
def _generate_batch_predictions(self,
texts: List[str],
model: torch.nn.Module,
tokenizer: Union[DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer],
titles: Optional[List[str]] = None, #useful only for passage embedding with DPR!
batch_size: int = 16) -> List[Tuple[object, np.array]]:
n = len(texts)
total = 0
results = []
for batch_start in range(0, n, batch_size):
# create batch of titles only for passages
ctx_title = None
if self.embed_title and titles:
ctx_title = titles[batch_start:batch_start + batch_size]
# create batch of text
ctx_text = texts[batch_start:batch_start + batch_size]
# tensorize the batch
ctx_ids_batch, _, ctx_attn_mask = self._tensorizer(tokenizer, text=ctx_text, title=ctx_title)
ctx_seg_batch = torch.zeros_like(ctx_ids_batch).to(self.device)
# remove [SEP] token from untitled passages in batch
if self.embed_title and self.remove_sep_tok_from_untitled_passages and ctx_title:
ctx_ids_batch, ctx_attn_mask = self._remove_sep_tok_from_untitled_passages(ctx_title,
ctx_ids_batch,
ctx_attn_mask)
with torch.no_grad():
out = model(input_ids=ctx_ids_batch, attention_mask=ctx_attn_mask, token_type_ids=ctx_seg_batch)
# TODO revert back to when updating transformers
# out = out.pooler_output
out = out[0]
out = out.cpu()
total += ctx_ids_batch.size()[0]
results.extend([
(out[i].view(-1).numpy())
for i in range(out.size(0))
])
if total % 10 == 0:
logger.info(f'Embedded {total} / {n} texts')
return results
class EmbeddingRetriever(BaseRetriever):
def __init__(


@@ -21,4 +21,4 @@ tika
uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
httptools
nltk
more_itertools
more_itertools


@@ -180,7 +180,7 @@ def dpr_retriever(faiss_document_store):
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
embed_title=True,
remove_sep_tok_from_untitled_passages=True
use_fast_tokenizers=True
)
@@ -288,14 +288,10 @@ def get_document_store(document_store_type, faiss_document_store, inmemory_docum
def get_retriever(retriever_type, document_store):
if retriever_type == "dpr":
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
embed_title=True,
remove_sep_tok_from_untitled_passages=True
)
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False, embed_title=True)
elif retriever_type == "tfidf":
return TfidfRetriever(document_store=document_store)
elif retriever_type == "embedding":


@@ -45,11 +45,11 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
# FAISSDocumentStore doesn't return embeddings, so these tests only work with ElasticsearchDocumentStore
if isinstance(document_store, ElasticsearchDocumentStore):
assert (len(docs_with_emb[0].embedding) == 768)
assert (abs(docs_with_emb[0].embedding[0] - (-0.30634)) < 0.001)
assert (abs(docs_with_emb[1].embedding[0] - (-0.37449)) < 0.001)
assert (abs(docs_with_emb[2].embedding[0] - (-0.24695)) < 0.001)
assert (abs(docs_with_emb[3].embedding[0] - (-0.08017)) < 0.001)
assert (abs(docs_with_emb[4].embedding[0] - (-0.01534)) < 0.001)
assert (abs(docs_with_emb[0].embedding[0] - (-0.3063)) < 0.001)
assert (abs(docs_with_emb[1].embedding[0] - (-0.3914)) < 0.001)
assert (abs(docs_with_emb[2].embedding[0] - (-0.2470)) < 0.001)
assert (abs(docs_with_emb[3].embedding[0] - (-0.0802)) < 0.001)
assert (abs(docs_with_emb[4].embedding[0] - (-0.0551)) < 0.001)
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?")


@@ -515,11 +515,12 @@
"retriever = DensePassageRetriever(document_store=document_store,\n",
" query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
" passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
" max_seq_len_query=64,\n",
" max_seq_len_passage=256,\n",
" batch_size=16,\n",
" use_gpu=True,\n",
" embed_title=True,\n",
" max_seq_len=256,\n",
" batch_size=16,\n",
" remove_sep_tok_from_untitled_passages=True)\n",
" use_fast_tokenizers=True)\n",
"# Important: \n",
"# Now that after we have the DPR initialized, we need to call update_embeddings() to iterate over all\n",
"# previously indexed documents and update their embedding representation. \n",


@@ -29,9 +29,13 @@ document_store.write_documents(dicts)
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
max_seq_len_query=64,
max_seq_len_passage=256,
batch_size=2,
use_gpu=True,
embed_title=True,
remove_sep_tok_from_untitled_passages=True)
use_fast_tokenizers=True
)
# Important:
# Now that we have the DPR initialized, we need to call update_embeddings() to iterate over all


@@ -1,196 +1,195 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Tutorial7_RAG_Generator.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Tutorial7_RAG_Generator.ipynb",
"provenance": [],
"collapsed_sections": []
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "iDyfhfyp7Sjh"
},
"source": [
"!pip install git+https://github.com/deepset-ai/haystack.git\n",
"!pip install urllib3==1.25.4"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ICZanGLa7khF"
},
"source": [
"from typing import List\n",
"import requests\n",
"import pandas as pd\n",
"from haystack import Document\n",
"from haystack.document_store.faiss import FAISSDocumentStore\n",
"from haystack.generator.transformers import RAGenerator\n",
"from haystack.retriever.dense import DensePassageRetriever"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "D3f-CQ4c7lEN"
},
"source": [
"# Add documents from which you want generate answers\n",
"# Download a csv containing some sample documents data\n",
"# Here some sample documents data\n",
"temp = requests.get(\"https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv\")\n",
"open('small_generator_dataset.csv', 'wb').write(temp.content)\n",
"\n",
"# Get dataframe with columns \"title\", and \"text\"\n",
"df = pd.read_csv(\"small_generator_dataset.csv\", sep=',')\n",
"# Minimal cleaning\n",
"df.fillna(value=\"\", inplace=True)\n",
"\n",
"print(df.head())\n",
"\n",
"# Create to haystack document format\n",
"titles = list(df[\"title\"].values)\n",
"texts = list(df[\"text\"].values)\n",
"\n",
"documents: List[Document] = []\n",
"for title, text in zip(titles, texts):\n",
" documents.append(\n",
" Document(\n",
" text=text,\n",
" meta={\n",
" \"name\": title or \"\"\n",
" }\n",
" )\n",
" )"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "upRu3ebX7nr_"
},
"source": [
"# Initialize FAISS document store to documents and corresponding index for embeddings\n",
"# Set `return_embedding` to `True`, so generator doesn't have to perform re-embedding\n",
"document_store = FAISSDocumentStore(\n",
" faiss_index_factory_str=\"Flat\",\n",
" return_embedding=True\n",
")\n",
"\n",
"# Initialize DPR Retriever to encode documents, encode question and query documents\n",
"retriever = DensePassageRetriever(\n",
" document_store=document_store,\n",
" query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
" passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
" use_gpu=False,\n",
" embed_title=True,\n",
" remove_sep_tok_from_untitled_passages=True,\n",
")\n",
"\n",
"# Initialize RAG Generator\n",
"generator = RAGenerator(\n",
" model_name_or_path=\"facebook/rag-token-nq\",\n",
" use_gpu=False,\n",
" top_k_answers=1,\n",
" max_length=200,\n",
" min_length=2,\n",
" embed_title=True,\n",
" num_beams=2,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "as8j7hkW7rOW"
},
"source": [
"# Delete existing documents in documents store\n",
"document_store.delete_all_documents()\n",
"# Write documents to document store\n",
"document_store.write_documents(documents)\n",
"# Add documents embeddings to index\n",
"document_store.update_embeddings(\n",
" retriever=retriever\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "j8It45R872vb",
"cellView": "form"
},
"source": [
"#@title\n",
"# Now ask your questions\n",
"# We have some sample questions\n",
"QUESTIONS = [\n",
" \"who got the first nobel prize in physics\",\n",
" \"when is the next deadpool movie being released\",\n",
" \"which mode is used for short wave broadcast service\",\n",
" \"who is the owner of reading football club\",\n",
" \"when is the next scandal episode coming out\",\n",
" \"when is the last time the philadelphia won the superbowl\",\n",
" \"what is the most current adobe flash player version\",\n",
" \"how many episodes are there in dragon ball z\",\n",
" \"what is the first step in the evolution of the eye\",\n",
" \"where is gall bladder situated in human body\",\n",
" \"what is the main mineral in lithium batteries\",\n",
" \"who is the president of usa right now\",\n",
" \"where do the greasers live in the outsiders\",\n",
" \"panda is a national animal of which country\",\n",
" \"what is the name of manchester united stadium\",\n",
"]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xPUHRuTP742h"
},
"source": [
"# Now generate answer for question\n",
"for question in QUESTIONS:\n",
" # Retrieve related documents from retriever\n",
" retriever_results = retriever.retrieve(\n",
" query=question\n",
" )\n",
"\n",
" # Now generate answer from question and retrieved documents\n",
" predicted_result = generator.predict(\n",
" question=question,\n",
" documents=retriever_results,\n",
" top_k=1\n",
" )\n",
"\n",
" # Print you answer\n",
" answers = predicted_result[\"answers\"]\n",
" print(f'Generated answer is \\'{answers[0][\"answer\"]}\\' for the question = \\'{question}\\'')"
],
"execution_count": null,
"outputs": []
}
]
}
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "iDyfhfyp7Sjh"
},
"source": [
"!pip install git+https://github.com/deepset-ai/haystack.git\n",
"!pip install urllib3==1.25.4"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ICZanGLa7khF"
},
"source": [
"from typing import List\n",
"import requests\n",
"import pandas as pd\n",
"from haystack import Document\n",
"from haystack.document_store.faiss import FAISSDocumentStore\n",
"from haystack.generator.transformers import RAGenerator\n",
"from haystack.retriever.dense import DensePassageRetriever"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "D3f-CQ4c7lEN"
},
"source": [
"# Add documents from which you want generate answers\n",
"# Download a csv containing some sample documents data\n",
"# Here some sample documents data\n",
"temp = requests.get(\"https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv\")\n",
"open('small_generator_dataset.csv', 'wb').write(temp.content)\n",
"\n",
"# Get dataframe with columns \"title\", and \"text\"\n",
"df = pd.read_csv(\"small_generator_dataset.csv\", sep=',')\n",
"# Minimal cleaning\n",
"df.fillna(value=\"\", inplace=True)\n",
"\n",
"print(df.head())\n",
"\n",
"# Create to haystack document format\n",
"titles = list(df[\"title\"].values)\n",
"texts = list(df[\"text\"].values)\n",
"\n",
"documents: List[Document] = []\n",
"for title, text in zip(titles, texts):\n",
" documents.append(\n",
" Document(\n",
" text=text,\n",
" meta={\n",
" \"name\": title or \"\"\n",
" }\n",
" )\n",
" )"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "upRu3ebX7nr_"
},
"source": [
"# Initialize FAISS document store to documents and corresponding index for embeddings\n",
"# Set `return_embedding` to `True`, so generator doesn't have to perform re-embedding\n",
"document_store = FAISSDocumentStore(\n",
" faiss_index_factory_str=\"Flat\",\n",
" return_embedding=True\n",
")\n",
"\n",
"# Initialize DPR Retriever to encode documents, encode question and query documents\n",
"retriever = DensePassageRetriever(\n",
" document_store=document_store,\n",
" query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
" passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
" use_gpu=False,\n",
" embed_title=True,\n",
")\n",
"\n",
"# Initialize RAG Generator\n",
"generator = RAGenerator(\n",
" model_name_or_path=\"facebook/rag-token-nq\",\n",
" use_gpu=False,\n",
" top_k_answers=1,\n",
" max_length=200,\n",
" min_length=2,\n",
" embed_title=True,\n",
" num_beams=2,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "as8j7hkW7rOW"
},
"source": [
"# Delete existing documents in documents store\n",
"document_store.delete_all_documents()\n",
"# Write documents to document store\n",
"document_store.write_documents(documents)\n",
"# Add documents embeddings to index\n",
"document_store.update_embeddings(\n",
" retriever=retriever\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "j8It45R872vb",
"cellView": "form"
},
"source": [
"#@title\n",
"# Now ask your questions\n",
"# We have some sample questions\n",
"QUESTIONS = [\n",
" \"who got the first nobel prize in physics\",\n",
" \"when is the next deadpool movie being released\",\n",
" \"which mode is used for short wave broadcast service\",\n",
" \"who is the owner of reading football club\",\n",
" \"when is the next scandal episode coming out\",\n",
" \"when is the last time the philadelphia won the superbowl\",\n",
" \"what is the most current adobe flash player version\",\n",
" \"how many episodes are there in dragon ball z\",\n",
" \"what is the first step in the evolution of the eye\",\n",
" \"where is gall bladder situated in human body\",\n",
" \"what is the main mineral in lithium batteries\",\n",
" \"who is the president of usa right now\",\n",
" \"where do the greasers live in the outsiders\",\n",
" \"panda is a national animal of which country\",\n",
" \"what is the name of manchester united stadium\",\n",
"]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xPUHRuTP742h"
},
"source": [
"# Now generate answer for question\n",
"for question in QUESTIONS:\n",
" # Retrieve related documents from retriever\n",
" retriever_results = retriever.retrieve(\n",
" query=question\n",
" )\n",
"\n",
" # Now generate answer from question and retrieved documents\n",
" predicted_result = generator.predict(\n",
" question=question,\n",
" documents=retriever_results,\n",
" top_k=1\n",
" )\n",
"\n",
" # Print you answer\n",
" answers = predicted_result[\"answers\"]\n",
" print(f'Generated answer is \\'{answers[0][\"answer\"]}\\' for the question = \\'{question}\\'')"
],
"execution_count": null,
"outputs": []
}
]
}


@@ -50,7 +50,6 @@ retriever = DensePassageRetriever(
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False,
embed_title=True,
remove_sep_tok_from_untitled_passages=True,
)
# Initialize RAG Generator