diff --git a/README.rst b/README.rst
index d0b390ffd..222050947 100644
--- a/README.rst
+++ b/README.rst
@@ -159,8 +159,11 @@ Example
 .. code-block:: python

     retriever = DensePassageRetriever(document_store=document_store,
-                                      embedding_model="dpr-bert-base-nq",
-                                      do_lower_case=True, use_gpu=True)
+                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+                                      use_gpu=True,
+                                      batch_size=16,
+                                      embed_title=True)
     retriever.retrieve(query="Why did the revenue increase?")
     # returns: [Document, Document]
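The README change above maps directly onto the reworked ``DensePassageRetriever`` in ``haystack/retriever/dense.py`` below. As an end-to-end sketch of the new API (the Elasticsearch store and the sample document are illustrative choices, not part of this diff):

.. code-block:: python

    from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
    from haystack.retriever.dense import DensePassageRetriever

    document_store = ElasticsearchDocumentStore(host="localhost", index="document")

    # embed_title=True reads the title from doc.meta["name"], so supply it when writing documents
    document_store.write_documents([
        {"text": "Revenue increased due to higher subscription numbers.",
         "meta": {"name": "Q3 report"}},
    ])

    retriever = DensePassageRetriever(document_store=document_store,
                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                      max_seq_len_query=64,
                                      max_seq_len_passage=256,
                                      use_gpu=True,
                                      batch_size=16,
                                      embed_title=True)

    # existing documents must be (re-)embedded with the new encoders before dense retrieval works
    document_store.update_embeddings(retriever=retriever)
    docs = retriever.retrieve(query="Why did the revenue increase?", top_k=10)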
diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index 7da146276..216277d95 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -2,16 +2,29 @@ import logging
 from typing import List, Union, Tuple, Optional

 import torch
 import numpy as np
+from pathlib import Path
+from tqdm import tqdm

 from farm.infer import Inferencer
 from haystack.document_store.base import BaseDocumentStore
 from haystack import Document
+from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
 from haystack.retriever.base import BaseRetriever
 from haystack.retriever.sparse import logger

-from transformers.modeling_dpr import DPRContextEncoder, DPRQuestionEncoder
-from transformers.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
+from farm.infer import Inferencer
+from farm.modeling.tokenization import Tokenizer
+from farm.modeling.language_model import LanguageModel
+from farm.modeling.biadaptive_model import BiAdaptiveModel
+from farm.modeling.prediction_head import TextSimilarityHead
+from farm.data_handler.processor import TextSimilarityProcessor
+from farm.data_handler.data_silo import DataSilo
+from farm.data_handler.dataloader import NamedDataLoader
+from farm.modeling.optimization import initialize_optimizer
+from farm.train import Trainer
+from torch.utils.data.sampler import SequentialSampler
+

 logger = logging.getLogger(__name__)

@@ -28,11 +41,13 @@ class DensePassageRetriever(BaseRetriever):
                  document_store: BaseDocumentStore,
                  query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base",
                  passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base",
-                 max_seq_len: int = 256,
+                 max_seq_len_query: int = 64,
+                 max_seq_len_passage: int = 256,
                  use_gpu: bool = True,
                  batch_size: int = 16,
                  embed_title: bool = True,
-                 remove_sep_tok_from_untitled_passages: bool = True
+                 use_fast_tokenizers: bool = True,
+                 similarity_function: str = "dot_product"
                  ):
        """
        Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -45,7 +60,8 @@ class DensePassageRetriever(BaseRetriever):
        :param passage_embedding_model: Local path or remote name of passage encoder checkpoint.
                                        The format equals the one used by hugging-face transformers' modelhub models
                                        Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
-        :param max_seq_len: Longest length of each sequence
+        :param max_seq_len_query: Maximum number of tokens for the query text. Longer sequences will be truncated.
+        :param max_seq_len_passage: Maximum number of tokens for the passage/context text. Longer sequences will be truncated.
        :param use_gpu: Whether to use gpu or not
        :param batch_size: Number of questions or passages to encode at once
        :param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
@@ -54,14 +70,12 @@ class DensePassageRetriever(BaseRetriever):
                            The title is expected to be present in doc.meta["name"] and can be supplied in the documents
                            before writing them to the DocumentStore like this:
                            {"text": "my text", "meta": {"name": "my title"}}.
-        :param remove_sep_tok_from_untitled_passages: If embed_title is ``True``, there are different strategies to deal with documents that don't have a title.
-                                                      If this param is ``True`` => Embed passage as single text, similar to embed_title = False (i.e [CLS] passage_tok1 ... [SEP]).
-                                                      If this param is ``False`` => Embed passage as text pair with empty title (i.e. [CLS] [SEP] passage_tok1 ... [SEP])
        """

        self.document_store = document_store
        self.batch_size = batch_size
-        self.max_seq_len = max_seq_len
+        self.max_seq_len_passage = max_seq_len_passage
+        self.max_seq_len_query = max_seq_len_query

        if use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
@@ -69,14 +83,39 @@ class DensePassageRetriever(BaseRetriever):
            self.device = torch.device("cpu")

        self.embed_title = embed_title
-        self.remove_sep_tok_from_untitled_passages = remove_sep_tok_from_untitled_passages

        # Init & Load Encoders
-        self.query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(query_embedding_model)
-        self.query_encoder = DPRQuestionEncoder.from_pretrained(query_embedding_model).to(self.device)
+        self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
+                                              do_lower_case=True, use_fast=use_fast_tokenizers)
+        self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
+                                                language_model_class="DPRQuestionEncoder")

-        self.passage_tokenizer = DPRContextEncoderTokenizer.from_pretrained(passage_embedding_model)
-        self.passage_encoder = DPRContextEncoder.from_pretrained(passage_embedding_model).to(self.device)
+        self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                do_lower_case=True, use_fast=use_fast_tokenizers)
+        self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                  language_model_class="DPRContextEncoder")
+
+        self.processor = TextSimilarityProcessor(tokenizer=self.query_tokenizer,
+                                                 passage_tokenizer=self.passage_tokenizer,
+                                                 max_seq_len_passage=self.max_seq_len_passage,
+                                                 max_seq_len_query=self.max_seq_len_query,
+                                                 label_list=["hard_negative", "positive"],
+                                                 metric="text_similarity_metric",
+                                                 embed_title=self.embed_title,
+                                                 num_hard_negatives=0,
+                                                 num_negatives=0)
+
+        prediction_head = TextSimilarityHead(similarity_function=similarity_function)
+        self.model = BiAdaptiveModel(
+            language_model1=self.query_encoder,
+            language_model2=self.passage_encoder,
+            prediction_heads=[prediction_head],
+            embeds_dropout_prob=0.1,
+            lm1_output_types=["per_sequence"],
+            lm2_output_types=["per_sequence"],
+            device=self.device,
+        )
+        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)

     def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
        if index is None:
@@ -85,6 +124,48 @@ class DensePassageRetriever(BaseRetriever):
        documents = self.document_store.query_by_embedding(query_emb=query_emb[0], top_k=top_k, filters=filters,
                                                           index=index)
        return documents

+    def _get_predictions(self, dicts, tokenizer):
+        """
+        Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
+
+        :param dicts: list of dictionaries
+        examples: [{'query': "where is florida?"}, {'query': "who wrote lord of the rings?"}, ...]
+                  [{'passages': [{
+                      "title": 'Big Little Lies (TV series)',
+                      "text": 'series garnered several accolades. It received..',
+                      "label": 'positive',
+                      "external_id": '18768923'},
+                     {"title": 'Framlingham Castle',
+                      "text": 'Castle on the Hill "Castle on the Hill" is a song by English..',
+                      "label": 'positive',
+                      "external_id": '19930582'}, ...]
+        :return: dictionary of embeddings for "passages" and "query"
+        """
+        dataset, tensor_names, baskets = self.processor.dataset_from_dicts(
+            dicts, indices=[i for i in range(len(dicts))], return_baskets=True
+        )
+
+        data_loader = NamedDataLoader(
+            dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
+        )
+        all_embeddings = {"query": torch.tensor([]).to(self.device), "passages": torch.tensor([]).to(self.device)}
+        self.model.eval()
+        for i, batch in enumerate(tqdm(data_loader, desc="Inferencing Samples", unit=" Batches", disable=False)):
+            batch = {key: batch[key].to(self.device) for key in batch}
+
+            # get embeddings (the forward pass returns one embedding tensor per encoder)
+            with torch.no_grad():
+                query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
+                all_embeddings["query"] = torch.cat((all_embeddings["query"], query_embeddings), dim=0) \
+                    if isinstance(query_embeddings, torch.Tensor) else None
+                all_embeddings["passages"] = torch.cat((all_embeddings["passages"], passage_embeddings), dim=0) \
+                    if isinstance(passage_embeddings, torch.Tensor) else None
+
+        # convert embeddings to numpy array
+        for k, v in all_embeddings.items():
+            all_embeddings[k] = v.cpu().numpy() if v is not None else None
+        return all_embeddings
+
     def embed_queries(self, texts: List[str]) -> List[np.array]:
        """
        Create embeddings for a list of queries using the query encoder
@@ -92,10 +173,8 @@ class DensePassageRetriever(BaseRetriever):
        :param texts: Queries to embed
        :return: Embeddings, one per input queries
        """
-        queries = [self._normalize_query(q) for q in texts]
-        result = self._generate_batch_predictions(texts=queries, model=self.query_encoder,
-                                                  tokenizer=self.query_tokenizer,
-                                                  batch_size=self.batch_size)
+        queries = [{'query': q} for q in texts]
+        result = self._get_predictions(queries, self.query_tokenizer)["query"]
        return result

     def embed_passages(self, docs: List[Document]) -> List[np.array]:
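With this refactoring, ``embed_queries`` above and ``embed_passages`` below become thin wrappers that build these dicts and delegate to ``_get_predictions``. A small usage sketch (assuming the retriever initialized earlier; the 768-dim shape holds for the base DPR encoders, as the test changes further down confirm):

.. code-block:: python

    import numpy as np
    from haystack import Document

    # queries are wrapped as [{'query': ...}] internally; callers just pass strings
    query_embs = retriever.embed_queries(texts=["Why did the revenue increase?"])

    # passages are wrapped in the 'passages' dict format, with the title taken
    # from doc.meta["name"] when embed_title=True
    doc = Document(text="Revenue increased due to higher subscription numbers.",
                   meta={"name": "Q3 report"})
    passage_embs = retriever.embed_passages(docs=[doc])

    # both are numpy arrays of shape (n, 768); with the default
    # similarity_function="dot_product", relevance is scored like this
    # (the document store does the equivalent at query time):
    score = np.dot(query_embs[0], passage_embs[0])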
@@ -105,158 +184,114 @@ class DensePassageRetriever(BaseRetriever):
        :param docs: List of Document objects used to represent documents / passages in
                     a standardized way within Haystack.
        :return: Embeddings of documents / passages shape (batch_size, embedding_dim)
        """
-        texts = [d.text for d in docs]
-        titles = None
-        if self.embed_title:
-            titles = [d.meta["name"] if d.meta and "name" in d.meta else "" for d in docs]
+        passages = [{'passages': [{
+            "title": d.meta["name"] if d.meta and "name" in d.meta else "",
+            "text": d.text,
+            "label": d.meta["label"] if d.meta and "label" in d.meta else "positive",
+            "external_id": d.id}]
+        } for d in docs]
+        embeddings = self._get_predictions(passages, self.passage_tokenizer)["passages"]

-        result = self._generate_batch_predictions(texts=texts, titles=titles,
-                                                  model=self.passage_encoder,
-                                                  tokenizer=self.passage_tokenizer,
-                                                  batch_size=self.batch_size)
-        return result
+        return embeddings

-    def _normalize_query(self, query: str) -> str:
-        if query[-1] == '?':
-            query = query[:-1]
-        return query
-
-    def _tensorizer(self, tokenizer: Union[DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer],
-                    text: List[str],
-                    title: Optional[List[str]] = None,
-                    add_special_tokens: bool = True):
+    def train(self,
+              data_dir: str,
+              train_filename: str,
+              dev_filename: str = None,
+              test_filename: str = None,
+              batch_size: int = 2,
+              embed_title: bool = True,
+              num_hard_negatives: int = 1,
+              num_negatives: int = 0,
+              n_epochs: int = 3,
+              evaluate_every: int = 1000,
+              n_gpu: int = 1,
+              learning_rate: float = 1e-5,
+              epsilon: float = 1e-08,
+              weight_decay: float = 0.0,
+              num_warmup_steps: int = 100,
+              grad_acc_steps: int = 1,
+              optimizer_name: str = "TransformersAdamW",
+              optimizer_correct_bias: bool = True,
+              save_dir: str = "../saved_models/dpr-tutorial",
+              query_encoder_save_dir: str = "lm1",
+              passage_encoder_save_dir: str = "lm2"
+              ):
        """
-        Creates tensors from text sequences
-        :Example:
-            >>> ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained()
-            >>> dpr_object._tensorizer(tokenizer=ctx_tokenizer, text=passages, title=titles)
-
-        :param tokenizer: An instance of DPRQuestionEncoderTokenizer or DPRContextEncoderTokenizer.
-        :param text: list of text sequences to be tokenized
-        :param title: optional list of titles associated with each text sequence
-        :param add_special_tokens: boolean for whether to encode special tokens in each sequence
-
-        Returns:
-                token_ids: list of token ids from vocabulary
-                token_type_ids: list of token type ids
-                attention_mask: list of indices specifying which tokens should be attended to by the encoder
+        Train a DensePassageRetriever model
+        :param data_dir: Directory where training file, dev file and test file are present
+        :param train_filename: training filename
+        :param dev_filename: development set filename, file to be used by model in eval step of training
+        :param test_filename: test set filename, file to be used by model in test step after training
+        :param batch_size: total number of samples in 1 batch of data
+        :param embed_title: whether to concatenate passage title with each passage.
+                            The default setting in official DPR embeds passage title with the corresponding passage
+        :param num_hard_negatives: number of hard negative passages (passages which are very similar to the query
+                                   (i.e. score highly with BM25) but do not contain the answer)
+        :param num_negatives: number of negative passages (random passages from the dataset which do not contain the
+                              answer to the query)
+        :param n_epochs: number of epochs to train the model on
+        :param evaluate_every: number of training steps after which evaluation is run
+        :param n_gpu: number of gpus to train on
+        :param learning_rate: learning rate of optimizer
+        :param epsilon: epsilon parameter of optimizer
+        :param weight_decay: weight decay parameter of optimizer
+        :param grad_acc_steps: number of steps to accumulate gradient over before back-propagation is done
+        :param optimizer_name: what optimizer to use (default: TransformersAdamW)
+        :param num_warmup_steps: number of warmup steps
+        :param optimizer_correct_bias: Whether to correct bias in optimizer
+        :param save_dir: directory where models are saved
+        :param query_encoder_save_dir: directory inside save_dir where query_encoder model files are saved
+        :param passage_encoder_save_dir: directory inside save_dir where passage_encoder model files are saved
        """
-        # combine titles with passages only if some titles are present with passages
-        if self.embed_title and title:
-            final_text = [tuple((title_, text_)) for title_, text_ in zip(title, text)] #type: Union[List[Tuple[str, ...]], List[str]]
-        else:
-            final_text = text
-        out = tokenizer.batch_encode_plus(final_text, add_special_tokens=add_special_tokens, truncation=True,
-                                          max_length=self.max_seq_len,
-                                          pad_to_max_length=True)
+        self.embed_title = embed_title
+        self.processor = TextSimilarityProcessor(tokenizer=self.query_tokenizer,
+                                                 passage_tokenizer=self.passage_tokenizer,
+                                                 max_seq_len_passage=self.max_seq_len_passage,
+                                                 max_seq_len_query=self.max_seq_len_query,
+                                                 label_list=["hard_negative", "positive"],
+                                                 metric="text_similarity_metric",
+                                                 data_dir=data_dir,
+                                                 train_filename=train_filename,
+                                                 dev_filename=dev_filename,
+                                                 test_filename=test_filename,
+                                                 embed_title=self.embed_title,
+                                                 num_hard_negatives=num_hard_negatives,
+                                                 num_negatives=num_negatives)

-        token_ids = torch.tensor(out['input_ids']).to(self.device)
-        token_type_ids = torch.tensor(out['token_type_ids']).to(self.device)
-        attention_mask = torch.tensor(out['attention_mask']).to(self.device)
-        return token_ids, token_type_ids, attention_mask
+        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)

-    def _remove_sep_tok_from_untitled_passages(self, titles, ctx_ids_batch, ctx_attn_mask):
-        """
-        removes [SEP] token from untitled samples in batch. For batches which has some untitled passages, remove [SEP]
-        token used to segment titles and passage from untitled samples in the batch
-        (Official DPR code do not encode [SEP] tokens in untitled passages)
+        data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False)

-        :Example:
-            # Encoding passages with 'embed_title' = True. 1st passage is titled, 2nd passage is untitled
-            >>> texts = ['Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions.',
-                         'Democratic Republic of the Congo to the south. Angola's capital, Luanda, lies on the Atlantic coast in the northwest of the country.'
-                        ]
-            >> titles = ["0", '']
-            >>> token_ids, token_type_ids, attention_mask = self._tensorizer(self.passage_tokenizer, text=texts, title=titles)
-            >>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[0]]
-            ['[CLS]', '0', '[SEP]', 'aaron', 'aaron', '(', 'or', ';', ....]
-            >>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[1]]
-            ['[CLS]', '[SEP]', 'democratic', 'republic', 'of', 'the', ....]
-            >>> new_ids, new_attn = self._remove_sep_tok_from_untitled_passages(titles, token_ids, attention_mask)
-            >>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[0]]
-            ['[CLS]', '0', '[SEP]', 'aaron', 'aaron', '(', 'or', ';', ....]
-            >>> [self.passage_tokenizer.ids_to_tokens[tok.item()] for tok in token_ids[1]]
-            ['[CLS]', 'democratic', 'republic', 'of', 'the', 'congo', ...]
+        # 5. Create an optimizer
+        self.model, optimizer, lr_schedule = initialize_optimizer(
+            model=self.model,
+            learning_rate=learning_rate,
+            optimizer_opts={"name": optimizer_name, "correct_bias": optimizer_correct_bias,
+                            "weight_decay": weight_decay, "eps": epsilon},
+            schedule_opts={"name": "LinearWarmup", "num_warmup_steps": num_warmup_steps},
+            n_batches=len(data_silo.loaders["train"]),
+            n_epochs=n_epochs,
+            grad_acc_steps=grad_acc_steps,
+            device=self.device
+        )

-        :param titles: list of titles for each sample
-        :param ctx_ids_batch: tensor of shape (batch_size, max_seq_len) containing token indices
-        :param ctx_attn_mask: tensor of shape (batch_size, max_seq_len) containing attention mask
+        # 6. Feed everything to the Trainer, which takes care of growing our model and evaluates it from time to time
+        trainer = Trainer(
+            model=self.model,
+            optimizer=optimizer,
+            data_silo=data_silo,
+            epochs=n_epochs,
+            n_gpu=n_gpu,
+            lr_schedule=lr_schedule,
+            evaluate_every=evaluate_every,
+            device=self.device,
+        )

-        Returns:
-                ctx_ids_batch: tensor of shape (batch_size, max_seq_len) containing token indices with [SEP] token removed
-                ctx_attn_mask: tensor of shape (batch_size, max_seq_len) reflecting the ctx_ids_batch changes
-        """
-        # Skip [SEP] removal if passage encoder not bert model
-        if self.passage_encoder.ctx_encoder.base_model_prefix != 'bert_model':
-            logger.warning("Context encoder is not a BERT model. Skipping removal of [SEP] tokens")
-            return ctx_ids_batch, ctx_attn_mask
+        # 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
+        trainer.train()

-        # create a mask for titles in the batch
-        titles_mask = torch.tensor(list(map(lambda x: 0 if x == "" else 1, titles))).to(self.device)
+        self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
+        self.processor.save(Path(save_dir))

-        # get all untitled passage indices
-        no_title_indices = torch.nonzero(1 - titles_mask).squeeze(-1)
-
-        # remove [SEP] token index for untitled passages and add 1 pad to compensate
-        ctx_ids_batch[no_title_indices] = torch.cat((ctx_ids_batch[no_title_indices, 0].unsqueeze(-1),
-                                                     ctx_ids_batch[no_title_indices, 2:],
-                                                     torch.tensor([self.passage_tokenizer.pad_token_id]).expand(len(no_title_indices)).unsqueeze(-1).to(self.device)),
-                                                    dim=1)
-        # Modify attention mask to reflect [SEP] token removal and pad addition in ctx_ids_batch
-        ctx_attn_mask[no_title_indices] = torch.cat((ctx_attn_mask[no_title_indices, 0].unsqueeze(-1),
-                                                     ctx_attn_mask[no_title_indices, 2:],
-                                                     torch.tensor([self.passage_tokenizer.pad_token_id]).expand(len(no_title_indices)).unsqueeze(-1).to(self.device)),
-                                                    dim=1)
-
-        return ctx_ids_batch, ctx_attn_mask
-
-    def _generate_batch_predictions(self,
-                                    texts: List[str],
-                                    model: torch.nn.Module,
-                                    tokenizer: Union[DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer],
-                                    titles: Optional[List[str]] = None,  #useful only for passage embedding with DPR!
-                                    batch_size: int = 16) -> List[Tuple[object, np.array]]:
-        n = len(texts)
-        total = 0
-        results = []
-        for batch_start in range(0, n, batch_size):
-            # create batch of titles only for passages
-            ctx_title = None
-            if self.embed_title and titles:
-                ctx_title = titles[batch_start:batch_start + batch_size]
-
-            # create batch of text
-            ctx_text = texts[batch_start:batch_start + batch_size]
-
-            # tensorize the batch
-            ctx_ids_batch, _, ctx_attn_mask = self._tensorizer(tokenizer, text=ctx_text, title=ctx_title)
-            ctx_seg_batch = torch.zeros_like(ctx_ids_batch).to(self.device)
-
-            # remove [SEP] token from untitled passages in batch
-            if self.embed_title and self.remove_sep_tok_from_untitled_passages and ctx_title:
-                ctx_ids_batch, ctx_attn_mask = self._remove_sep_tok_from_untitled_passages(ctx_title,
-                                                                                           ctx_ids_batch,
-                                                                                           ctx_attn_mask)
-
-            with torch.no_grad():
-                out = model(input_ids=ctx_ids_batch, attention_mask=ctx_attn_mask, token_type_ids=ctx_seg_batch)
-                # TODO revert back to when updating transformers
-                # out = out.pooler_output
-                out = out[0]
-            out = out.cpu()
-
-            total += ctx_ids_batch.size()[0]
-
-            results.extend([
-                (out[i].view(-1).numpy())
-                for i in range(out.size(0))
-            ])
-
-            if total % 10 == 0:
-                logger.info(f'Embedded {total} / {n} texts')
-
-        return results

 class EmbeddingRetriever(BaseRetriever):
     def __init__(
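The new ``train()`` above wires FARM's ``TextSimilarityProcessor``, ``DataSilo`` and ``Trainer`` together. A minimal fine-tuning sketch — the file names are placeholders, and the assumption that the files follow the official DPR JSON training format (question / positive_ctxs / hard_negative_ctxs) is mine, not stated in this diff:

.. code-block:: python

    retriever.train(data_dir="data/retriever",
                    train_filename="train.json",
                    dev_filename="dev.json",
                    test_filename="dev.json",
                    n_epochs=1,
                    batch_size=4,
                    grad_acc_steps=4,
                    embed_title=True,
                    num_hard_negatives=1,
                    evaluate_every=3000,
                    save_dir="../saved_models/dpr")

    # the encoders land in save_dir under "lm1" (query) and "lm2" (passage) and can
    # be passed back in via query_embedding_model / passage_embedding_model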
diff --git a/requirements.txt b/requirements.txt
index b4100dda8..409df00bc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,4 +21,4 @@ tika
 uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
 httptools
 nltk
-more_itertools
\ No newline at end of file
+more_itertools
diff --git a/test/conftest.py b/test/conftest.py
index 5f64485a4..af6888eb0 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -180,7 +180,7 @@ def dpr_retriever(faiss_document_store):
                                       passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                       use_gpu=False,
                                       embed_title=True,
-                                      remove_sep_tok_from_untitled_passages=True
+                                      use_fast_tokenizers=True
     )


@@ -288,14 +288,10 @@ def get_document_store(document_store_type, faiss_document_store, inmemory_document_store):
 def get_retriever(retriever_type, document_store):

     if retriever_type == "dpr":
-        retriever = DensePassageRetriever(
-            document_store=document_store,
-            query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
-            passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
-            use_gpu=False,
-            embed_title=True,
-            remove_sep_tok_from_untitled_passages=True
-        )
+        retriever = DensePassageRetriever(document_store=document_store,
+                                          query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+                                          passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+                                          use_gpu=False, embed_title=True)
     elif retriever_type == "tfidf":
         return TfidfRetriever(document_store=document_store)
     elif retriever_type == "embedding":
diff --git a/test/test_dpr_retriever.py b/test/test_dpr_retriever.py
index 432aad9ac..f1e892066 100644
--- a/test/test_dpr_retriever.py
+++ b/test/test_dpr_retriever.py
@@ -45,11 +45,11 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
     # FAISSDocumentStore doesn't return embeddings, so these tests only work with ElasticsearchDocumentStore
     if isinstance(document_store, ElasticsearchDocumentStore):
         assert (len(docs_with_emb[0].embedding) == 768)
-        assert (abs(docs_with_emb[0].embedding[0] - (-0.30634)) < 0.001)
-        assert (abs(docs_with_emb[1].embedding[0] - (-0.37449)) < 0.001)
-        assert (abs(docs_with_emb[2].embedding[0] - (-0.24695)) < 0.001)
-        assert (abs(docs_with_emb[3].embedding[0] - (-0.08017)) < 0.001)
-        assert (abs(docs_with_emb[4].embedding[0] - (-0.01534)) < 0.001)
+        assert (abs(docs_with_emb[0].embedding[0] - (-0.3063)) < 0.001)
+        assert (abs(docs_with_emb[1].embedding[0] - (-0.3914)) < 0.001)
+        assert (abs(docs_with_emb[2].embedding[0] - (-0.2470)) < 0.001)
+        assert (abs(docs_with_emb[3].embedding[0] - (-0.0802)) < 0.001)
+        assert (abs(docs_with_emb[4].embedding[0] - (-0.0551)) < 0.001)

     res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?")
diff --git a/tutorials/Tutorial6_Better_Retrieval_via_DPR.ipynb b/tutorials/Tutorial6_Better_Retrieval_via_DPR.ipynb
index e92fc216c..eda2065b0 100644
--- a/tutorials/Tutorial6_Better_Retrieval_via_DPR.ipynb
+++ b/tutorials/Tutorial6_Better_Retrieval_via_DPR.ipynb
@@ -515,11 +515,12 @@
     "retriever = DensePassageRetriever(document_store=document_store,\n",
     "                                  query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
     "                                  passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
+    "                                  max_seq_len_query=64,\n",
+    "                                  max_seq_len_passage=256,\n",
+    "                                  batch_size=16,\n",
     "                                  use_gpu=True,\n",
     "                                  embed_title=True,\n",
-    "                                  max_seq_len=256,\n",
-    "                                  batch_size=16,\n",
-    "                                  remove_sep_tok_from_untitled_passages=True)\n",
+    "                                  use_fast_tokenizers=True)\n",
     "# Important: \n",
     "# Now that we have the DPR initialized, we need to call update_embeddings() to iterate over all\n",
     "# previously indexed documents and update their embedding representation. \n",
\n", diff --git a/tutorials/Tutorial6_Better_Retrieval_via_DPR.py b/tutorials/Tutorial6_Better_Retrieval_via_DPR.py index c18dda7c5..e9a4a2dcc 100755 --- a/tutorials/Tutorial6_Better_Retrieval_via_DPR.py +++ b/tutorials/Tutorial6_Better_Retrieval_via_DPR.py @@ -29,9 +29,13 @@ document_store.write_documents(dicts) retriever = DensePassageRetriever(document_store=document_store, query_embedding_model="facebook/dpr-question_encoder-single-nq-base", passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base", + max_seq_len_query=64, + max_seq_len_passage=256, + batch_size=2, use_gpu=True, embed_title=True, - remove_sep_tok_from_untitled_passages=True) + use_fast_tokenizers=True + ) # Important: # Now that after we have the DPR initialized, we need to call update_embeddings() to iterate over all diff --git a/tutorials/Tutorial7_RAG_Generator.ipynb b/tutorials/Tutorial7_RAG_Generator.ipynb index f4b0f41a3..2618baeb7 100644 --- a/tutorials/Tutorial7_RAG_Generator.ipynb +++ b/tutorials/Tutorial7_RAG_Generator.ipynb @@ -1,196 +1,195 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Tutorial7_RAG_Generator.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Tutorial7_RAG_Generator.ipynb", + "provenance": [], + "collapsed_sections": [] }, - "cells": [ - { - "cell_type": "code", - "metadata": { - "id": "iDyfhfyp7Sjh" - }, - "source": [ - "!pip install git+https://github.com/deepset-ai/haystack.git\n", - "!pip install urllib3==1.25.4" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "ICZanGLa7khF" - }, - "source": [ - "from typing import List\n", - "import requests\n", - "import pandas as pd\n", - "from haystack import Document\n", - "from haystack.document_store.faiss import FAISSDocumentStore\n", - "from haystack.generator.transformers import RAGenerator\n", - "from haystack.retriever.dense import DensePassageRetriever" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "D3f-CQ4c7lEN" - }, - "source": [ - "# Add documents from which you want generate answers\n", - "# Download a csv containing some sample documents data\n", - "# Here some sample documents data\n", - "temp = requests.get(\"https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv\")\n", - "open('small_generator_dataset.csv', 'wb').write(temp.content)\n", - "\n", - "# Get dataframe with columns \"title\", and \"text\"\n", - "df = pd.read_csv(\"small_generator_dataset.csv\", sep=',')\n", - "# Minimal cleaning\n", - "df.fillna(value=\"\", inplace=True)\n", - "\n", - "print(df.head())\n", - "\n", - "# Create to haystack document format\n", - "titles = list(df[\"title\"].values)\n", - "texts = list(df[\"text\"].values)\n", - "\n", - "documents: List[Document] = []\n", - "for title, text in zip(titles, texts):\n", - " documents.append(\n", - " Document(\n", - " text=text,\n", - " meta={\n", - " \"name\": title or \"\"\n", - " }\n", - " )\n", - " )" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "upRu3ebX7nr_" - }, - "source": [ - "# Initialize FAISS document store to documents and corresponding index for embeddings\n", - "# Set `return_embedding` to `True`, so generator doesn't have to perform re-embedding\n", - "document_store = 
-                "document_store = FAISSDocumentStore(\n",
-                "    faiss_index_factory_str=\"Flat\",\n",
-                "    return_embedding=True\n",
-                ")\n",
-                "\n",
-                "# Initialize DPR Retriever to encode documents, encode question and query documents\n",
-                "retriever = DensePassageRetriever(\n",
-                "    document_store=document_store,\n",
-                "    query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
-                "    passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n",
-                "    use_gpu=False,\n",
-                "    embed_title=True,\n",
-                "    remove_sep_tok_from_untitled_passages=True,\n",
-                ")\n",
-                "\n",
-                "# Initialize RAG Generator\n",
-                "generator = RAGenerator(\n",
-                "    model_name_or_path=\"facebook/rag-token-nq\",\n",
-                "    use_gpu=False,\n",
-                "    top_k_answers=1,\n",
-                "    max_length=200,\n",
-                "    min_length=2,\n",
-                "    embed_title=True,\n",
-                "    num_beams=2,\n",
-                ")"
-            ],
-            "execution_count": null,
-            "outputs": []
-        },
-        {
-            "cell_type": "code",
-            "metadata": {
-                "id": "as8j7hkW7rOW"
-            },
-            "source": [
-                "# Delete existing documents in documents store\n",
-                "document_store.delete_all_documents()\n",
-                "# Write documents to document store\n",
-                "document_store.write_documents(documents)\n",
-                "# Add documents embeddings to index\n",
-                "document_store.update_embeddings(\n",
-                "    retriever=retriever\n",
-                ")"
-            ],
-            "execution_count": null,
-            "outputs": []
-        },
-        {
-            "cell_type": "code",
-            "metadata": {
-                "id": "j8It45R872vb",
-                "cellView": "form"
-            },
-            "source": [
-                "#@title\n",
-                "# Now ask your questions\n",
-                "# We have some sample questions\n",
-                "QUESTIONS = [\n",
-                "    \"who got the first nobel prize in physics\",\n",
-                "    \"when is the next deadpool movie being released\",\n",
-                "    \"which mode is used for short wave broadcast service\",\n",
-                "    \"who is the owner of reading football club\",\n",
-                "    \"when is the next scandal episode coming out\",\n",
-                "    \"when is the last time the philadelphia won the superbowl\",\n",
-                "    \"what is the most current adobe flash player version\",\n",
-                "    \"how many episodes are there in dragon ball z\",\n",
-                "    \"what is the first step in the evolution of the eye\",\n",
-                "    \"where is gall bladder situated in human body\",\n",
-                "    \"what is the main mineral in lithium batteries\",\n",
-                "    \"who is the president of usa right now\",\n",
-                "    \"where do the greasers live in the outsiders\",\n",
-                "    \"panda is a national animal of which country\",\n",
-                "    \"what is the name of manchester united stadium\",\n",
-                "]"
-            ],
-            "execution_count": null,
-            "outputs": []
-        },
-        {
-            "cell_type": "code",
-            "metadata": {
-                "id": "xPUHRuTP742h"
-            },
-            "source": [
-                "# Now generate answer for question\n",
-                "for question in QUESTIONS:\n",
-                "    # Retrieve related documents from retriever\n",
-                "    retriever_results = retriever.retrieve(\n",
-                "        query=question\n",
-                "    )\n",
-                "\n",
-                "    # Now generate answer from question and retrieved documents\n",
-                "    predicted_result = generator.predict(\n",
-                "        question=question,\n",
-                "        documents=retriever_results,\n",
-                "        top_k=1\n",
-                "    )\n",
-                "\n",
-                "    # Print you answer\n",
-                "    answers = predicted_result[\"answers\"]\n",
-                "    print(f'Generated answer is \\'{answers[0][\"answer\"]}\\' for the question = \\'{question}\\'')"
-            ],
-            "execution_count": null,
-            "outputs": []
-        }
-    ]
-}
\ No newline at end of file
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "Tutorial7_RAG_Generator.ipynb",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "iDyfhfyp7Sjh"
+      },
+      "source": [
+        "!pip install git+https://github.com/deepset-ai/haystack.git\n",
+        "!pip install urllib3==1.25.4"
+      ],
"execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "ICZanGLa7khF" + }, + "source": [ + "from typing import List\n", + "import requests\n", + "import pandas as pd\n", + "from haystack import Document\n", + "from haystack.document_store.faiss import FAISSDocumentStore\n", + "from haystack.generator.transformers import RAGenerator\n", + "from haystack.retriever.dense import DensePassageRetriever" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "D3f-CQ4c7lEN" + }, + "source": [ + "# Add documents from which you want generate answers\n", + "# Download a csv containing some sample documents data\n", + "# Here some sample documents data\n", + "temp = requests.get(\"https://raw.githubusercontent.com/deepset-ai/haystack/master/tutorials/small_generator_dataset.csv\")\n", + "open('small_generator_dataset.csv', 'wb').write(temp.content)\n", + "\n", + "# Get dataframe with columns \"title\", and \"text\"\n", + "df = pd.read_csv(\"small_generator_dataset.csv\", sep=',')\n", + "# Minimal cleaning\n", + "df.fillna(value=\"\", inplace=True)\n", + "\n", + "print(df.head())\n", + "\n", + "# Create to haystack document format\n", + "titles = list(df[\"title\"].values)\n", + "texts = list(df[\"text\"].values)\n", + "\n", + "documents: List[Document] = []\n", + "for title, text in zip(titles, texts):\n", + " documents.append(\n", + " Document(\n", + " text=text,\n", + " meta={\n", + " \"name\": title or \"\"\n", + " }\n", + " )\n", + " )" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "upRu3ebX7nr_" + }, + "source": [ + "# Initialize FAISS document store to documents and corresponding index for embeddings\n", + "# Set `return_embedding` to `True`, so generator doesn't have to perform re-embedding\n", + "document_store = FAISSDocumentStore(\n", + " faiss_index_factory_str=\"Flat\",\n", + " return_embedding=True\n", + ")\n", + "\n", + "# Initialize DPR Retriever to encode documents, encode question and query documents\n", + "retriever = DensePassageRetriever(\n", + " document_store=document_store,\n", + " query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n", + " passage_embedding_model=\"facebook/dpr-ctx_encoder-single-nq-base\",\n", + " use_gpu=False,\n", + " embed_title=True,\n", + ")\n", + "\n", + "# Initialize RAG Generator\n", + "generator = RAGenerator(\n", + " model_name_or_path=\"facebook/rag-token-nq\",\n", + " use_gpu=False,\n", + " top_k_answers=1,\n", + " max_length=200,\n", + " min_length=2,\n", + " embed_title=True,\n", + " num_beams=2,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "as8j7hkW7rOW" + }, + "source": [ + "# Delete existing documents in documents store\n", + "document_store.delete_all_documents()\n", + "# Write documents to document store\n", + "document_store.write_documents(documents)\n", + "# Add documents embeddings to index\n", + "document_store.update_embeddings(\n", + " retriever=retriever\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "j8It45R872vb", + "cellView": "form" + }, + "source": [ + "#@title\n", + "# Now ask your questions\n", + "# We have some sample questions\n", + "QUESTIONS = [\n", + " \"who got the first nobel prize in physics\",\n", + " \"when is the next deadpool movie being released\",\n", + " \"which mode is used for short wave broadcast 
service\",\n", + " \"who is the owner of reading football club\",\n", + " \"when is the next scandal episode coming out\",\n", + " \"when is the last time the philadelphia won the superbowl\",\n", + " \"what is the most current adobe flash player version\",\n", + " \"how many episodes are there in dragon ball z\",\n", + " \"what is the first step in the evolution of the eye\",\n", + " \"where is gall bladder situated in human body\",\n", + " \"what is the main mineral in lithium batteries\",\n", + " \"who is the president of usa right now\",\n", + " \"where do the greasers live in the outsiders\",\n", + " \"panda is a national animal of which country\",\n", + " \"what is the name of manchester united stadium\",\n", + "]" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "xPUHRuTP742h" + }, + "source": [ + "# Now generate answer for question\n", + "for question in QUESTIONS:\n", + " # Retrieve related documents from retriever\n", + " retriever_results = retriever.retrieve(\n", + " query=question\n", + " )\n", + "\n", + " # Now generate answer from question and retrieved documents\n", + " predicted_result = generator.predict(\n", + " question=question,\n", + " documents=retriever_results,\n", + " top_k=1\n", + " )\n", + "\n", + " # Print you answer\n", + " answers = predicted_result[\"answers\"]\n", + " print(f'Generated answer is \\'{answers[0][\"answer\"]}\\' for the question = \\'{question}\\'')" + ], + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/tutorials/Tutorial7_RAG_Generator.py b/tutorials/Tutorial7_RAG_Generator.py index 09b958460..d6f96e3f7 100644 --- a/tutorials/Tutorial7_RAG_Generator.py +++ b/tutorials/Tutorial7_RAG_Generator.py @@ -50,7 +50,6 @@ retriever = DensePassageRetriever( passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base", use_gpu=False, embed_title=True, - remove_sep_tok_from_untitled_passages=True, ) # Initialize RAG Generator