Allow non-standard Tokenizers (e.g. CamemBERT) for DPR via new arg (#811)

* added parameter to infer DPR tokenizers class

* Add latest docstring and tutorial changes

* Update docstring. fix mypy

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
Pavel Soriano 2021-02-12 14:17:55 +01:00 committed by GitHub
parent c4607cbd98
commit 8adf5b4737
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 3 deletions

View File

@@ -225,7 +225,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
#### \_\_init\_\_ #### \_\_init\_\_
```python ```python
| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", progress_bar: bool = True) | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", progress_bar: bool = True)
``` ```
Init the Retriever incl. the two encoder models from a local or remote model checkpoint. Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -264,6 +264,11 @@ titles contain meaningful information for retrieval (topic, entities etc.) .
The title is expected to be present in doc.meta["name"] and can be supplied in the documents The title is expected to be present in doc.meta["name"] and can be supplied in the documents
before writing them to the DocumentStore like this: before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}. {"text": "my text", "meta": {"name": "my title"}}.
- `use_fast_tokenizers`: Whether to use fast Rust tokenizers
- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name.
If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`.
- `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training.
Options: `dot_product` (Default) or `cosine`
- `progress_bar`: Whether to show a tqdm progress bar or not. - `progress_bar`: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean. Can be helpful to disable in production deployments to keep the logs clean.

View File

@@ -46,6 +46,7 @@ class DensePassageRetriever(BaseRetriever):
batch_size: int = 16, batch_size: int = 16,
embed_title: bool = True, embed_title: bool = True,
use_fast_tokenizers: bool = True, use_fast_tokenizers: bool = True,
infer_tokenizer_classes: bool = False,
similarity_function: str = "dot_product", similarity_function: str = "dot_product",
progress_bar: bool = True progress_bar: bool = True
): ):
@@ -84,6 +85,11 @@ class DensePassageRetriever(BaseRetriever):
The title is expected to be present in doc.meta["name"] and can be supplied in the documents The title is expected to be present in doc.meta["name"] and can be supplied in the documents
before writing them to the DocumentStore like this: before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}. {"text": "my text", "meta": {"name": "my title"}}.
:param use_fast_tokenizers: Whether to use fast Rust tokenizers
:param infer_tokenizer_classes: Whether to infer tokenizer class from the model config / name.
If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`.
:param similarity_function: Which function to apply for calculating the similarity of query and passage embeddings during training.
Options: `dot_product` (Default) or `cosine`
:param progress_bar: Whether to show a tqdm progress bar or not. :param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean. Can be helpful to disable in production deployments to keep the logs clean.
""" """
@@ -109,13 +115,21 @@ class DensePassageRetriever(BaseRetriever):
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.embed_title = embed_title self.embed_title = embed_title
self.infer_tokenizer_classes = infer_tokenizer_classes
tokenizers_default_classes = {
"query": "DPRQuestionEncoderTokenizer",
"passage": "DPRContextEncoderTokenizer"
}
if self.infer_tokenizer_classes:
tokenizers_default_classes["query"] = None # type: ignore
tokenizers_default_classes["passage"] = None # type: ignore
# Init & Load Encoders # Init & Load Encoders
self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model, self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
revision=model_version, revision=model_version,
do_lower_case=True, do_lower_case=True,
use_fast=use_fast_tokenizers, use_fast=use_fast_tokenizers,
tokenizer_class="DPRQuestionEncoderTokenizer") tokenizer_class=tokenizers_default_classes["query"])
self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model, self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
revision=model_version, revision=model_version,
language_model_class="DPRQuestionEncoder") language_model_class="DPRQuestionEncoder")
@@ -123,7 +137,7 @@ class DensePassageRetriever(BaseRetriever):
revision=model_version, revision=model_version,
do_lower_case=True, do_lower_case=True,
use_fast=use_fast_tokenizers, use_fast=use_fast_tokenizers,
tokenizer_class="DPRContextEncoderTokenizer") tokenizer_class=tokenizers_default_classes["passage"])
self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model, self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
revision=model_version, revision=model_version,
language_model_class="DPRContextEncoder") language_model_class="DPRContextEncoder")