diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index db1211812..ed6793e68 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -225,7 +225,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que #### \_\_init\_\_ ```python - | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", progress_bar: bool = True) + | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", progress_bar: bool = True) ``` Init the Retriever incl. the two encoder models from a local or remote model checkpoint. @@ -264,6 +264,11 @@ titles contain meaningful information for retrieval (topic, entities etc.) . The title is expected to be present in doc.meta["name"] and can be supplied in the documents before writing them to the DocumentStore like this: {"text": "my text", "meta": {"name": "my title"}}. +- `use_fast_tokenizers`: Whether to use fast Rust tokenizers +- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name. +If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`. 
+- `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training. +Options: `dot_product` (default) or `cosine`. - `progress_bar`: Whether to show a tqdm progress bar or not. Can be helpful to disable in production deployments to keep the logs clean. diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py index fe5d207cc..90a619481 100644 --- a/haystack/retriever/dense.py +++ b/haystack/retriever/dense.py @@ -46,6 +46,7 @@ class DensePassageRetriever(BaseRetriever): batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, + infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", progress_bar: bool = True ): @@ -84,6 +85,11 @@ class DensePassageRetriever(BaseRetriever): The title is expected to be present in doc.meta["name"] and can be supplied in the documents before writing them to the DocumentStore like this: {"text": "my text", "meta": {"name": "my title"}}. + :param use_fast_tokenizers: Whether to use fast Rust tokenizers + :param infer_tokenizer_classes: Whether to infer tokenizer class from the model config / name. + If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`. + :param similarity_function: Which function to apply for calculating the similarity of query and passage embeddings during training. + Options: `dot_product` (default) or `cosine`. :param progress_bar: Whether to show a tqdm progress bar or not. Can be helpful to disable in production deployments to keep the logs clean. 
""" @@ -109,13 +115,21 @@ class DensePassageRetriever(BaseRetriever): self.device = torch.device("cpu") self.embed_title = embed_title + self.infer_tokenizer_classes = infer_tokenizer_classes + tokenizers_default_classes = { + "query": "DPRQuestionEncoderTokenizer", + "passage": "DPRContextEncoderTokenizer" + } + if self.infer_tokenizer_classes: + tokenizers_default_classes["query"] = None # type: ignore + tokenizers_default_classes["passage"] = None # type: ignore # Init & Load Encoders self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model, revision=model_version, do_lower_case=True, use_fast=use_fast_tokenizers, - tokenizer_class="DPRQuestionEncoderTokenizer") + tokenizer_class=tokenizers_default_classes["query"]) self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model, revision=model_version, language_model_class="DPRQuestionEncoder") @@ -123,7 +137,7 @@ class DensePassageRetriever(BaseRetriever): revision=model_version, do_lower_case=True, use_fast=use_fast_tokenizers, - tokenizer_class="DPRContextEncoderTokenizer") + tokenizer_class=tokenizers_default_classes["passage"]) self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model, revision=model_version, language_model_class="DPRContextEncoder")