DensePassageRetriever: Add Training, Refactor Inference to FARM modules (#527)
* dpr training and inference code refactored with FARM modules
* dpr test cases modified
* docstring and default arguments updated
* dpr training docstring updated
* bugfix in dense retriever inference, DPR tutorials modified
* Bump FARM to 0.5.0
* update README for DPR
* mypy errors fix
* DPR instantiation bugfix
* Fix DPR init in RAG Tutorial

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent f13443054a · commit 72b637ae6d
@@ -159,8 +159,11 @@ Example

.. code-block:: python

    retriever = DensePassageRetriever(document_store=document_store,
                                      embedding_model="dpr-bert-base-nq",
                                      do_lower_case=True, use_gpu=True)
                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                      use_gpu=True,
                                      batch_size=16,
                                      embed_title=True)
    retriever.retrieve(query="Why did the revenue increase?")
    # returns: [Document, Document]
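The tutorials touched below pair this new two-encoder initialization with a re-embedding pass over already indexed documents. A minimal sketch, assuming a local Elasticsearch instance is running:

.. code-block:: python

    from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
    from haystack.retriever.dense import DensePassageRetriever

    document_store = ElasticsearchDocumentStore()  # assumes Elasticsearch on localhost
    retriever = DensePassageRetriever(document_store=document_store,
                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                      use_gpu=True,
                                      batch_size=16,
                                      embed_title=True)
    document_store.update_embeddings(retriever=retriever)  # re-embed previously indexed documents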
@@ -2,16 +2,29 @@ import logging

from typing import List, Union, Tuple, Optional
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm

from farm.infer import Inferencer

from haystack.document_store.base import BaseDocumentStore
from haystack import Document
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.base import BaseRetriever
from haystack.retriever.sparse import logger

from transformers.modeling_dpr import DPRContextEncoder, DPRQuestionEncoder
from transformers.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
from farm.infer import Inferencer
from farm.modeling.tokenization import Tokenizer
from farm.modeling.language_model import LanguageModel
from farm.modeling.biadaptive_model import BiAdaptiveModel
from farm.modeling.prediction_head import TextSimilarityHead
from farm.data_handler.processor import TextSimilarityProcessor
from farm.data_handler.data_silo import DataSilo
from farm.data_handler.dataloader import NamedDataLoader
from farm.modeling.optimization import initialize_optimizer
from farm.train import Trainer
from torch.utils.data.sampler import SequentialSampler

logger = logging.getLogger(__name__)
@@ -28,11 +41,13 @@ class DensePassageRetriever(BaseRetriever):

                 document_store: BaseDocumentStore,
                 query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base",
                 passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base",
                 max_seq_len: int = 256,
                 max_seq_len_query: int = 64,
                 max_seq_len_passage: int = 256,
                 use_gpu: bool = True,
                 batch_size: int = 16,
                 embed_title: bool = True,
                 remove_sep_tok_from_untitled_passages: bool = True
                 use_fast_tokenizers: bool = True,
                 similarity_function: str = "dot_product"
                 ):
        """
        Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -45,7 +60,8 @@ class DensePassageRetriever(BaseRetriever):

        :param passage_embedding_model: Local path or remote name of passage encoder checkpoint. The format equals the
                                        one used by hugging-face transformers' modelhub models
                                        Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
        :param max_seq_len: Longest length of each sequence
        :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down.
        :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down.
        :param use_gpu: Whether to use gpu or not
        :param batch_size: Number of questions or passages to encode at once
        :param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
@@ -54,14 +70,12 @@ class DensePassageRetriever(BaseRetriever):

                            The title is expected to be present in doc.meta["name"] and can be supplied in the documents
                            before writing them to the DocumentStore like this:
                            {"text": "my text", "meta": {"name": "my title"}}.
        :param remove_sep_tok_from_untitled_passages: If embed_title is ``True``, there are different strategies to deal with documents that don't have a title.
                                                      If this param is ``True`` => Embed passage as single text, similar to embed_title = False (i.e [CLS] passage_tok1 ... [SEP]).
                                                      If this param is ``False`` => Embed passage as text pair with empty title (i.e. [CLS] [SEP] passage_tok1 ... [SEP])
        """

        self.document_store = document_store
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.max_seq_len_passage = max_seq_len_passage
        self.max_seq_len_query = max_seq_len_query

        if use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
@@ -69,14 +83,39 @@ class DensePassageRetriever(BaseRetriever):

            self.device = torch.device("cpu")

        self.embed_title = embed_title
        self.remove_sep_tok_from_untitled_passages = remove_sep_tok_from_untitled_passages

        # Init & Load Encoders
        self.query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(query_embedding_model)
        self.query_encoder = DPRQuestionEncoder.from_pretrained(query_embedding_model).to(self.device)
        self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
                                              do_lower_case=True, use_fast=use_fast_tokenizers)
        self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
                                                language_model_class="DPRQuestionEncoder")

        self.passage_tokenizer = DPRContextEncoderTokenizer.from_pretrained(passage_embedding_model)
        self.passage_encoder = DPRContextEncoder.from_pretrained(passage_embedding_model).to(self.device)
        self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
                                                do_lower_case=True, use_fast=use_fast_tokenizers)
        self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
                                                  language_model_class="DPRContextEncoder")

        self.processor = TextSimilarityProcessor(tokenizer=self.query_tokenizer,
                                                 passage_tokenizer=self.passage_tokenizer,
                                                 max_seq_len_passage=self.max_seq_len_passage,
                                                 max_seq_len_query=self.max_seq_len_query,
                                                 label_list=["hard_negative", "positive"],
                                                 metric="text_similarity_metric",
                                                 embed_title=self.embed_title,
                                                 num_hard_negatives=0,
                                                 num_negatives=0)

        prediction_head = TextSimilarityHead(similarity_function=similarity_function)
        self.model = BiAdaptiveModel(
            language_model1=self.query_encoder,
            language_model2=self.passage_encoder,
            prediction_heads=[prediction_head],
            embeds_dropout_prob=0.1,
            lm1_output_types=["per_sequence"],
            lm2_output_types=["per_sequence"],
            device=self.device,
        )
        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
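For intuition on the ``similarity_function="dot_product"`` default wired into the ``TextSimilarityHead`` above: query/passage relevance reduces to a dot product of the two encoders' ``per_sequence`` outputs. A minimal numpy sketch with stand-in vectors (not the FARM head itself):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    query_embs = rng.normal(size=(2, 768))    # stand-ins for question encoder outputs
    passage_embs = rng.normal(size=(5, 768))  # stand-ins for context encoder outputs
    scores = query_embs @ passage_embs.T      # (2, 5) dot-product similarity matrix
    best = scores.argmax(axis=1)              # highest-scoring passage per query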
    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
        if index is None:

@@ -85,6 +124,48 @@ class DensePassageRetriever(BaseRetriever):

        documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters, index=index)
        return documents

    def _get_predictions(self, dicts, tokenizer):
        """
        Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).

        :param dicts: list of dictionaries
        examples: [{'query': "where is florida?"}, {'query': "who wrote lord of the rings?"}, ...]
                  [{'passages': [{
                      "title": 'Big Little Lies (TV series)',
                      "text": 'series garnered several accolades. It received..',
                      "label": 'positive',
                      "external_id": '18768923'},
                      {"title": 'Framlingham Castle',
                      "text": 'Castle on the Hill "Castle on the Hill" is a song by English..',
                      "label": 'positive',
                      "external_id": '19930582'}, ...]
        :return: dictionary of embeddings for "passages" and "query"
        """
        dataset, tensor_names, baskets = self.processor.dataset_from_dicts(
            dicts, indices=[i for i in range(len(dicts))], return_baskets=True
        )

        data_loader = NamedDataLoader(
            dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
        )
        all_embeddings = {"query": torch.tensor([]).to(self.device), "passages": torch.tensor([]).to(self.device)}
        self.model.eval()
        for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing Samples", unit=" Batches", disable=False)):
            batch = {key: batch[key].to(self.device) for key in batch}

            # get logits
            with torch.no_grad():
                query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
                all_embeddings["query"] = torch.cat((all_embeddings["query"], query_embeddings), dim=0) \
                    if isinstance(query_embeddings, torch.Tensor) else None
                all_embeddings["passages"] = torch.cat((all_embeddings["passages"], passage_embeddings), dim=0) \
                    if isinstance(passage_embeddings, torch.Tensor) else None

        # convert embeddings to numpy array
        for k, v in all_embeddings.items():
            all_embeddings[k] = v.cpu().numpy() if v is not None else None
        return all_embeddings
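A hedged sketch of the contract this implies, following the docstring above (``_get_predictions`` is internal; ``embed_queries`` and ``embed_passages`` below are the public entry points):

.. code-block:: python

    embs = retriever._get_predictions([{'query': "where is florida?"}], retriever.query_tokenizer)
    embs["query"]     # numpy array with one 768-dim row per input query
    embs["passages"]  # presumably None here, since no passage dicts were passed in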
    def embed_queries(self, texts: List[str]) -> List[np.array]:
        """
        Create embeddings for a list of queries using the query encoder

@@ -92,10 +173,8 @@ class DensePassageRetriever(BaseRetriever):

        :param texts: Queries to embed
        :return: Embeddings, one per input query
        """
        queries = [self._normalize_query(q) for q in texts]
        result = self._generate_batch_predictions(texts=queries, model=self.query_encoder,
                                                  tokenizer=self.query_tokenizer,
                                                  batch_size=self.batch_size)
        queries = [{'query': q} for q in texts]
        result = self._get_predictions(queries, self.query_tokenizer)["query"]
        return result
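For orientation, a minimal usage sketch of the refactored query path (the 768-dim vector size matches the test assertions further down):

.. code-block:: python

    query_embs = retriever.embed_queries(texts=["Why did the revenue increase?"])
    # query_embs[0] is a 768-dimensional numpy vector from the question encoder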
    def embed_passages(self, docs: List[Document]) -> List[np.array]:

@@ -105,158 +184,114 @@ class DensePassageRetriever(BaseRetriever):

        :param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack.
        :return: Embeddings of documents / passages shape (batch_size, embedding_dim)
        """
        texts = [d.text for d in docs]
        titles = None
        if self.embed_title:
            titles = [d.meta["name"] if d.meta and "name" in d.meta else "" for d in docs]
        passages = [{'passages': [{
            "title": d.meta["name"] if d.meta and "name" in d.meta else "",
            "text": d.text,
            "label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
            "external_id": d.id}]
        } for d in docs]
        embeddings = self._get_predictions(passages, self.passage_tokenizer)["passages"]

        result = self._generate_batch_predictions(texts=texts, titles=titles,
                                                  model=self.passage_encoder,
                                                  tokenizer=self.passage_tokenizer,
                                                  batch_size=self.batch_size)
        return result
        return embeddings
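And the matching passage path, a minimal sketch reusing the ``meta`` format shown in the ``__init__`` docstring:

.. code-block:: python

    from haystack import Document

    docs = [Document(text="my text", meta={"name": "my title"})]
    passage_embs = retriever.embed_passages(docs)  # one 768-dim vector per document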
    def _normalize_query(self, query: str) -> str:
        if query[-1] == '?':
            query = query[:-1]
        return query

    def _tensorizer(self, tokenizer: Union[DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer],
                    text: List[str],
                    title: Optional[List[str]] = None,
                    add_special_tokens: bool = True):
    def train(self,
              data_dir: str,
              train_filename: str,
              dev_filename: str = None,
              test_filename: str = None,
              batch_size: int = 2,
              embed_title: bool = True,
              num_hard_negatives: int = 1,
              num_negatives: int = 0,
              n_epochs: int = 3,
              evaluate_every: int = 1000,
              n_gpu: int = 1,
              learning_rate: float = 1e-5,
              epsilon: float = 1e-08,
              weight_decay: float = 0.0,
              num_warmup_steps: int = 100,
              grad_acc_steps: int = 1,
              optimizer_name: str = "TransformersAdamW",
              optimizer_correct_bias: bool = True,
              save_dir: str = "../saved_models/dpr-tutorial",
              query_encoder_save_dir: str = "lm1",
              passage_encoder_save_dir: str = "lm2"
              ):
        """
        Creates tensors from text sequences
        :Example:
            >>> ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained()
            >>> dpr_object._tensorizer(tokenizer=ctx_tokenizer, text=passages, title=titles)

        :param tokenizer: An instance of DPRQuestionEncoderTokenizer or DPRContextEncoderTokenizer.
        :param text: list of text sequences to be tokenized
        :param title: optional list of titles associated with each text sequence
        :param add_special_tokens: boolean for whether to encode special tokens in each sequence

        Returns:
            token_ids: list of token ids from vocabulary
            token_type_ids: list of token type ids
            attention_mask: list of indices specifying which tokens should be attended to by the encoder
        Train a DensePassageRetriever model
        :param data_dir: Directory where the training file, dev file and test file are present
        :param train_filename: training filename
        :param dev_filename: development set filename, file to be used by model in eval step of training
        :param test_filename: test set filename, file to be used by model in test step after training
        :param batch_size: total number of samples in 1 batch of data
        :param embed_title: whether to concatenate the passage title with each passage. The default setting in official DPR embeds the passage title with the corresponding passage
        :param num_hard_negatives: number of hard negative passages (passages which are very similar to the query, e.g. a high BM25 score, but do not contain the answer)
        :param num_negatives: number of negative passages (any random passage from the dataset which does not contain the answer to the query)
        :param n_epochs: number of epochs to train the model for
        :param evaluate_every: number of training steps after which evaluation is run
        :param n_gpu: number of gpus to train on
        :param learning_rate: learning rate of optimizer
        :param epsilon: epsilon parameter of optimizer
        :param weight_decay: weight decay parameter of optimizer
        :param grad_acc_steps: number of steps to accumulate gradient over before back-propagation is done
        :param optimizer_name: what optimizer to use (default: TransformersAdamW)
        :param num_warmup_steps: number of warmup steps
        :param optimizer_correct_bias: Whether to correct bias in optimizer
        :param save_dir: directory where models are saved
        :param query_encoder_save_dir: directory inside save_dir where query_encoder model files are saved
        :param passage_encoder_save_dir: directory inside save_dir where passage_encoder model files are saved
        """
        # combine titles with passages only if some titles are present with passages
        if self.embed_title and title:
            final_text = [tuple((title_, text_)) for title_, text_ in zip(title, text)] #type: Union[List[Tuple[str, ...]], List[str]]
        else:
            final_text = text
        out = tokenizer.batch_encode_plus(final_text, add_special_tokens=add_special_tokens, truncation=True,
                                          max_length=self.max_seq_len,
                                          pad_to_max_length=True)
        self.embed_title = embed_title
        self.processor = TextSimilarityProcessor(tokenizer=self.query_tokenizer,
                                                 passage_tokenizer=self.passage_tokenizer,
                                                 max_seq_len_passage=self.max_seq_len_passage,
                                                 max_seq_len_query=self.max_seq_len_query,
                                                 label_list=["hard_negative", "positive"],
                                                 metric="text_similarity_metric",
                                                 data_dir=data_dir,
                                                 train_filename=train_filename,
                                                 dev_filename=dev_filename,
                                                 test_filename=test_filename,
                                                 embed_title=self.embed_title,
                                                 num_hard_negatives=num_hard_negatives,
                                                 num_negatives=num_negatives)

        token_ids = torch.tensor(out['input_ids']).to(self.device)
        token_type_ids = torch.tensor(out['token_type_ids']).to(self.device)
        attention_mask = torch.tensor(out['attention_mask']).to(self.device)
        return token_ids, token_type_ids, attention_mask
        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)

    def _remove_sep_tok_from_untitled_passages(self, titles, ctx_ids_batch, ctx_attn_mask):
        """
        Removes the [SEP] token from untitled samples in a batch. For batches which have some untitled passages, remove the [SEP]
        token used to segment titles and passage from the untitled samples in the batch
        (the official DPR code does not encode [SEP] tokens in untitled passages)
        data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False)

        :Example:
            # Encoding passages with 'embed_title' = True. 1st passage is titled, 2nd passage is untitled
            >>> texts = ['Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions.',
                         'Democratic Republic of the Congo to the south. Angola\'s capital, Luanda, lies on the Atlantic coast in the northwest of the country.'
                        ]
            >>> titles = ["0", '']
            >>> token_ids, token_type_ids, attention_mask = self._tensorizer(self.passage_tokenizer, text=texts, title=titles)
            >>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[0]]
            ['[CLS]', '0', '[SEP]', 'aaron', 'aaron', '(', 'or', ';', ....]
            >>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[1]]
            ['[CLS]', '[SEP]', 'democratic', 'republic', 'of', 'the', ....]
            >>> new_ids, new_attn = self._remove_sep_tok_from_untitled_passages(titles, token_ids, attention_mask)
            >>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[0]]
            ['[CLS]', '0', '[SEP]', 'aaron', 'aaron', '(', 'or', ';', ....]
            >>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[1]]
            ['[CLS]', 'democratic', 'republic', 'of', 'the', 'congo', ...]
        # 5. Create an optimizer
        self.model, optimizer, lr_schedule = initialize_optimizer(
            model=self.model,
            learning_rate=learning_rate,
            optimizer_opts={"name": optimizer_name, "correct_bias": optimizer_correct_bias,
                            "weight_decay": weight_decay, "eps": epsilon},
            schedule_opts={"name": "LinearWarmup", "num_warmup_steps": num_warmup_steps},
            n_batches=len(data_silo.loaders["train"]),
            n_epochs=n_epochs,
            grad_acc_steps=grad_acc_steps,
            device=self.device
        )

        :param titles: list of titles for each sample
        :param ctx_ids_batch: tensor of shape (batch_size, max_seq_len) containing token indices
        :param ctx_attn_mask: tensor of shape (batch_size, max_seq_len) containing attention mask

        # 6. Feed everything to the Trainer, which takes care of growing our model and evaluates it from time to time
        trainer = Trainer(
            model=self.model,
            optimizer=optimizer,
            data_silo=data_silo,
            epochs=n_epochs,
            n_gpu=n_gpu,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            device=self.device,
        )

        Returns:
            ctx_ids_batch: tensor of shape (batch_size, max_seq_len) containing token indices with [SEP] token removed
            ctx_attn_mask: tensor of shape (batch_size, max_seq_len) reflecting the ctx_ids_batch changes
        """
        # Skip [SEP] removal if the passage encoder is not a BERT model
        if self.passage_encoder.ctx_encoder.base_model_prefix != 'bert_model':
            logger.warning("Context encoder is not a BERT model. Skipping removal of [SEP] tokens")
            return ctx_ids_batch, ctx_attn_mask
        # 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
        trainer.train()

        # create a mask for titles in the batch
        titles_mask = torch.tensor(list(map(lambda x: 0 if x == "" else 1, titles))).to(self.device)
        self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
        self.processor.save(Path(save_dir))

        # get all untitled passage indices
        no_title_indices = torch.nonzero(1 - titles_mask).squeeze(-1)

        # remove [SEP] token index for untitled passages and add 1 pad to compensate
        ctx_ids_batch[no_title_indices] = torch.cat((ctx_ids_batch[no_title_indices, 0].unsqueeze(-1),
                                                     ctx_ids_batch[no_title_indices, 2:],
                                                     torch.tensor([self.passage_tokenizer.pad_token_id]).expand(len(no_title_indices)).unsqueeze(-1).to(self.device)),
                                                    dim=1)
        # Modify attention mask to reflect [SEP] token removal and pad addition in ctx_ids_batch
        ctx_attn_mask[no_title_indices] = torch.cat((ctx_attn_mask[no_title_indices, 0].unsqueeze(-1),
                                                     ctx_attn_mask[no_title_indices, 2:],
                                                     torch.tensor([self.passage_tokenizer.pad_token_id]).expand(len(no_title_indices)).unsqueeze(-1).to(self.device)),
                                                    dim=1)

        return ctx_ids_batch, ctx_attn_mask
    def _generate_batch_predictions(self,
                                    texts: List[str],
                                    model: torch.nn.Module,
                                    tokenizer: Union[DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer],
                                    titles: Optional[List[str]] = None, #useful only for passage embedding with DPR!
                                    batch_size: int = 16) -> List[Tuple[object, np.array]]:
        n = len(texts)
        total = 0
        results = []
        for batch_start in range(0, n, batch_size):
            # create batch of titles only for passages
            ctx_title = None
            if self.embed_title and titles:
                ctx_title = titles[batch_start:batch_start + batch_size]

            # create batch of text
            ctx_text = texts[batch_start:batch_start + batch_size]

            # tensorize the batch
            ctx_ids_batch, _, ctx_attn_mask = self._tensorizer(tokenizer, text=ctx_text, title=ctx_title)
            ctx_seg_batch = torch.zeros_like(ctx_ids_batch).to(self.device)

            # remove [SEP] token from untitled passages in batch
            if self.embed_title and self.remove_sep_tok_from_untitled_passages and ctx_title:
                ctx_ids_batch, ctx_attn_mask = self._remove_sep_tok_from_untitled_passages(ctx_title,
                                                                                           ctx_ids_batch,
                                                                                           ctx_attn_mask)

            with torch.no_grad():
                out = model(input_ids=ctx_ids_batch, attention_mask=ctx_attn_mask, token_type_ids=ctx_seg_batch)
                # TODO revert back to when updating transformers
                # out = out.pooler_output
                out = out[0]
            out = out.cpu()

            total += ctx_ids_batch.size()[0]

            results.extend([
                (out[i].view(-1).numpy())
                for i in range(out.size(0))
            ])

            if total % 10 == 0:
                logger.info(f'Embedded {total} / {n} texts')

        return results

class EmbeddingRetriever(BaseRetriever):
    def __init__(
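Pulling the new training API together, a minimal sketch of a call to ``train()`` (the data directory and filenames are placeholders; ``document_store`` is assumed to be initialized as above):

.. code-block:: python

    retriever = DensePassageRetriever(document_store=document_store,
                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
    retriever.train(data_dir="data/dpr_training",      # placeholder paths
                    train_filename="train.json",
                    dev_filename="dev.json",
                    test_filename="test.json",
                    n_epochs=3,
                    batch_size=2,
                    num_hard_negatives=1,
                    evaluate_every=1000,
                    save_dir="../saved_models/dpr-tutorial")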
@@ -21,4 +21,4 @@ tika

uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
httptools
nltk
more_itertools
more_itertools
@@ -180,7 +180,7 @@ def dpr_retriever(faiss_document_store):

        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        use_gpu=False,
        embed_title=True,
        remove_sep_tok_from_untitled_passages=True
        use_fast_tokenizers=True
    )
@@ -288,14 +288,10 @@ def get_document_store(document_store_type, faiss_document_store, inmemory_docum

def get_retriever(retriever_type, document_store):

    if retriever_type == "dpr":
        retriever = DensePassageRetriever(
            document_store=document_store,
            query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
            passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
            use_gpu=False,
            embed_title=True,
            remove_sep_tok_from_untitled_passages=True
        )
        retriever = DensePassageRetriever(document_store=document_store,
                                          query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                          passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                          use_gpu=False, embed_title=True)
    elif retriever_type == "tfidf":
        return TfidfRetriever(document_store=document_store)
    elif retriever_type == "embedding":
@@ -45,11 +45,11 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):

    # FAISSDocumentStore doesn't return embeddings, so these tests only work with ElasticsearchDocumentStore
    if isinstance(document_store, ElasticsearchDocumentStore):
        assert (len(docs_with_emb[0].embedding) == 768)
        assert (abs(docs_with_emb[0].embedding[0] - (-0.30634)) < 0.001)
        assert (abs(docs_with_emb[1].embedding[0] - (-0.37449)) < 0.001)
        assert (abs(docs_with_emb[2].embedding[0] - (-0.24695)) < 0.001)
        assert (abs(docs_with_emb[3].embedding[0] - (-0.08017)) < 0.001)
        assert (abs(docs_with_emb[4].embedding[0] - (-0.01534)) < 0.001)
        assert (abs(docs_with_emb[0].embedding[0] - (-0.3063)) < 0.001)
        assert (abs(docs_with_emb[1].embedding[0] - (-0.3914)) < 0.001)
        assert (abs(docs_with_emb[2].embedding[0] - (-0.2470)) < 0.001)
        assert (abs(docs_with_emb[3].embedding[0] - (-0.0802)) < 0.001)
        assert (abs(docs_with_emb[4].embedding[0] - (-0.0551)) < 0.001)

    res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?")
@@ -515,11 +515,12 @@

"retriever = DensePassageRetriever(document_store=document_store,\n",
" query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
" passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
" max_seq_len_query=64,\n",
" max_seq_len_passage=256,\n",
" batch_size=16,\n",
" use_gpu=True,\n",
" embed_title=True,\n",
" max_seq_len=256,\n",
" batch_size=16,\n",
" remove_sep_tok_from_untitled_passages=True)\n",
" use_fast_tokenizers=True)\n",
"# Important: \n",
"# Now that we have the DPR initialized, we need to call update_embeddings() to iterate over all\n",
"# previously indexed documents and update their embedding representation. \n",
@@ -29,9 +29,13 @@ document_store.write_documents(dicts)

retriever = DensePassageRetriever(document_store=document_store,
                                  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                  max_seq_len_query=64,
                                  max_seq_len_passage=256,
                                  batch_size=2,
                                  use_gpu=True,
                                  embed_title=True,
                                  remove_sep_tok_from_untitled_passages=True)
                                  use_fast_tokenizers=True
                                  )

# Important:
# Now that we have the DPR initialized, we need to call update_embeddings() to iterate over all
@@ -1,196 +1,195 @@

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Tutorial7_RAG_Generator.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Tutorial7_RAG_Generator.ipynb",
"provenance": [],
"collapsed_sections": []
},
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "iDyfhfyp7Sjh"
|
||||
},
|
||||
"source": [
|
||||
"!pip install git+https://github.com/deepset-ai/haystack.git\n",
|
||||
"!pip install urllib3==1.25.4"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "ICZanGLa7khF"
|
||||
},
|
||||
"source": [
|
||||
"from typing import List\n",
|
||||
"import requests\n",
|
||||
"import pandas as pd\n",
|
||||
"from haystack import Document\n",
|
||||
"from haystack.document_store.faiss import FAISSDocumentStore\n",
|
||||
"from haystack.generator.transformers import RAGenerator\n",
|
||||
"from haystack.retriever.dense import DensePassageRetriever"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "D3f-CQ4c7lEN"
|
||||
},
|
||||
"source": [
|
||||
"# Add documents from which you want generate answers\n",
|
||||
"# Download a csv containing some sample documents data\n",
|
||||
"# Here some sample documents data\n",
|
||||
"temp = requests.get(\"https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv\")\n",
|
||||
"open('small_generator_dataset.csv', 'wb').write(temp.content)\n",
|
||||
"\n",
|
||||
"# Get dataframe with columns \"title\", and \"text\"\n",
|
||||
"df = pd.read_csv(\"small_generator_dataset.csv\", sep=',')\n",
|
||||
"# Minimal cleaning\n",
|
||||
"df.fillna(value=\"\", inplace=True)\n",
|
||||
"\n",
|
||||
"print(df.head())\n",
|
||||
"\n",
|
||||
"# Create to haystack document format\n",
|
||||
"titles = list(df[\"title\"].values)\n",
|
||||
"texts = list(df[\"text\"].values)\n",
|
||||
"\n",
|
||||
"documents: List[Document] = []\n",
|
||||
"for title, text in zip(titles, texts):\n",
|
||||
" documents.append(\n",
|
||||
" Document(\n",
|
||||
" text=text,\n",
|
||||
" meta={\n",
|
||||
" \"name\": title or \"\"\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" )"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "upRu3ebX7nr_"
|
||||
},
|
||||
"source": [
|
||||
"# Initialize FAISS document store to documents and corresponding index for embeddings\n",
|
||||
"# Set `return_embedding` to `True`, so generator doesn't have to perform re-embedding\n",
|
||||
"document_store = FAISSDocumentStore(\n",
|
||||
" faiss_index_factory_str=\"Flat\",\n",
|
||||
" return_embedding=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Initialize DPR Retriever to encode documents, encode question and query documents\n",
|
||||
"retriever = DensePassageRetriever(\n",
|
||||
" document_store=document_store,\n",
|
||||
" query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
|
||||
" passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
|
||||
" use_gpu=False,\n",
|
||||
" embed_title=True,\n",
|
||||
" remove_sep_tok_from_untitled_passages=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Initialize RAG Generator\n",
|
||||
"generator = RAGenerator(\n",
|
||||
" model_name_or_path=\"facebook/rag-token-nq\",\n",
|
||||
" use_gpu=False,\n",
|
||||
" top_k_answers=1,\n",
|
||||
" max_length=200,\n",
|
||||
" min_length=2,\n",
|
||||
" embed_title=True,\n",
|
||||
" num_beams=2,\n",
|
||||
")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "as8j7hkW7rOW"
|
||||
},
|
||||
"source": [
|
||||
"# Delete existing documents in documents store\n",
|
||||
"document_store.delete_all_documents()\n",
|
||||
"# Write documents to document store\n",
|
||||
"document_store.write_documents(documents)\n",
|
||||
"# Add documents embeddings to index\n",
|
||||
"document_store.update_embeddings(\n",
|
||||
" retriever=retriever\n",
|
||||
")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "j8It45R872vb",
|
||||
"cellView": "form"
|
||||
},
|
||||
"source": [
|
||||
"#@title\n",
|
||||
"# Now ask your questions\n",
|
||||
"# We have some sample questions\n",
|
||||
"QUESTIONS = [\n",
|
||||
" \"who got the first nobel prize in physics\",\n",
|
||||
" \"when is the next deadpool movie being released\",\n",
|
||||
" \"which mode is used for short wave broadcast service\",\n",
|
||||
" \"who is the owner of reading football club\",\n",
|
||||
" \"when is the next scandal episode coming out\",\n",
|
||||
" \"when is the last time the philadelphia won the superbowl\",\n",
|
||||
" \"what is the most current adobe flash player version\",\n",
|
||||
" \"how many episodes are there in dragon ball z\",\n",
|
||||
" \"what is the first step in the evolution of the eye\",\n",
|
||||
" \"where is gall bladder situated in human body\",\n",
|
||||
" \"what is the main mineral in lithium batteries\",\n",
|
||||
" \"who is the president of usa right now\",\n",
|
||||
" \"where do the greasers live in the outsiders\",\n",
|
||||
" \"panda is a national animal of which country\",\n",
|
||||
" \"what is the name of manchester united stadium\",\n",
|
||||
"]"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "xPUHRuTP742h"
|
||||
},
|
||||
"source": [
|
||||
"# Now generate answer for question\n",
|
||||
"for question in QUESTIONS:\n",
|
||||
" # Retrieve related documents from retriever\n",
|
||||
" retriever_results = retriever.retrieve(\n",
|
||||
" query=question\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Now generate answer from question and retrieved documents\n",
|
||||
" predicted_result = generator.predict(\n",
|
||||
" question=question,\n",
|
||||
" documents=retriever_results,\n",
|
||||
" top_k=1\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Print you answer\n",
|
||||
" answers = predicted_result[\"answers\"]\n",
|
||||
" print(f'Generated answer is \\'{answers[0][\"answer\"]}\\' for the question = \\'{question}\\'')"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
}
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "iDyfhfyp7Sjh"
|
||||
},
|
||||
"source": [
|
||||
"!pip install git+https://github.com/deepset-ai/haystack.git\n",
|
||||
"!pip install urllib3==1.25.4"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "ICZanGLa7khF"
|
||||
},
|
||||
"source": [
|
||||
"from typing import List\n",
|
||||
"import requests\n",
|
||||
"import pandas as pd\n",
|
||||
"from haystack import Document\n",
|
||||
"from haystack.document_store.faiss import FAISSDocumentStore\n",
|
||||
"from haystack.generator.transformers import RAGenerator\n",
|
||||
"from haystack.retriever.dense import DensePassageRetriever"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "D3f-CQ4c7lEN"
|
||||
},
|
||||
"source": [
|
||||
"# Add documents from which you want generate answers\n",
|
||||
"# Download a csv containing some sample documents data\n",
|
||||
"# Here some sample documents data\n",
|
||||
"temp = requests.get(\"https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv\")\n",
|
||||
"open('small_generator_dataset.csv', 'wb').write(temp.content)\n",
|
||||
"\n",
|
||||
"# Get dataframe with columns \"title\", and \"text\"\n",
|
||||
"df = pd.read_csv(\"small_generator_dataset.csv\", sep=',')\n",
|
||||
"# Minimal cleaning\n",
|
||||
"df.fillna(value=\"\", inplace=True)\n",
|
||||
"\n",
|
||||
"print(df.head())\n",
|
||||
"\n",
|
||||
"# Create to haystack document format\n",
|
||||
"titles = list(df[\"title\"].values)\n",
|
||||
"texts = list(df[\"text\"].values)\n",
|
||||
"\n",
|
||||
"documents: List[Document] = []\n",
|
||||
"for title, text in zip(titles, texts):\n",
|
||||
" documents.append(\n",
|
||||
" Document(\n",
|
||||
" text=text,\n",
|
||||
" meta={\n",
|
||||
" \"name\": title or \"\"\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" )"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "upRu3ebX7nr_"
|
||||
},
|
||||
"source": [
|
||||
"# Initialize FAISS document store to documents and corresponding index for embeddings\n",
|
||||
"# Set `return_embedding` to `True`, so generator doesn't have to perform re-embedding\n",
|
||||
"document_store = FAISSDocumentStore(\n",
|
||||
" faiss_index_factory_str=\"Flat\",\n",
|
||||
" return_embedding=True\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Initialize DPR Retriever to encode documents, encode question and query documents\n",
|
||||
"retriever = DensePassageRetriever(\n",
|
||||
" document_store=document_store,\n",
|
||||
" query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
|
||||
" passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
|
||||
" use_gpu=False,\n",
|
||||
" embed_title=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Initialize RAG Generator\n",
|
||||
"generator = RAGenerator(\n",
|
||||
" model_name_or_path=\"facebook/rag-token-nq\",\n",
|
||||
" use_gpu=False,\n",
|
||||
" top_k_answers=1,\n",
|
||||
" max_length=200,\n",
|
||||
" min_length=2,\n",
|
||||
" embed_title=True,\n",
|
||||
" num_beams=2,\n",
|
||||
")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "as8j7hkW7rOW"
|
||||
},
|
||||
"source": [
|
||||
"# Delete existing documents in documents store\n",
|
||||
"document_store.delete_all_documents()\n",
|
||||
"# Write documents to document store\n",
|
||||
"document_store.write_documents(documents)\n",
|
||||
"# Add documents embeddings to index\n",
|
||||
"document_store.update_embeddings(\n",
|
||||
" retriever=retriever\n",
|
||||
")"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "j8It45R872vb",
|
||||
"cellView": "form"
|
||||
},
|
||||
"source": [
|
||||
"#@title\n",
|
||||
"# Now ask your questions\n",
|
||||
"# We have some sample questions\n",
|
||||
"QUESTIONS = [\n",
|
||||
" \"who got the first nobel prize in physics\",\n",
|
||||
" \"when is the next deadpool movie being released\",\n",
|
||||
" \"which mode is used for short wave broadcast service\",\n",
|
||||
" \"who is the owner of reading football club\",\n",
|
||||
" \"when is the next scandal episode coming out\",\n",
|
||||
" \"when is the last time the philadelphia won the superbowl\",\n",
|
||||
" \"what is the most current adobe flash player version\",\n",
|
||||
" \"how many episodes are there in dragon ball z\",\n",
|
||||
" \"what is the first step in the evolution of the eye\",\n",
|
||||
" \"where is gall bladder situated in human body\",\n",
|
||||
" \"what is the main mineral in lithium batteries\",\n",
|
||||
" \"who is the president of usa right now\",\n",
|
||||
" \"where do the greasers live in the outsiders\",\n",
|
||||
" \"panda is a national animal of which country\",\n",
|
||||
" \"what is the name of manchester united stadium\",\n",
|
||||
"]"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
"id": "xPUHRuTP742h"
|
||||
},
|
||||
"source": [
|
||||
"# Now generate answer for question\n",
|
||||
"for question in QUESTIONS:\n",
|
||||
" # Retrieve related documents from retriever\n",
|
||||
" retriever_results = retriever.retrieve(\n",
|
||||
" query=question\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Now generate answer from question and retrieved documents\n",
|
||||
" predicted_result = generator.predict(\n",
|
||||
" question=question,\n",
|
||||
" documents=retriever_results,\n",
|
||||
" top_k=1\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Print you answer\n",
|
||||
" answers = predicted_result[\"answers\"]\n",
|
||||
" print(f'Generated answer is \\'{answers[0][\"answer\"]}\\' for the question = \\'{question}\\'')"
|
||||
],
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -50,7 +50,6 @@ retriever = DensePassageRetriever(

    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    use_gpu=False,
    embed_title=True,
    remove_sep_tok_from_untitled_passages=True,
)

# Initialize RAG Generator