mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-24 08:28:22 +00:00
Allow non-standard Tokenizers (e.g. CamemBERT) for DPR via new arg (#811)
* added parameter to infer DPR tokenizers class * Add latest docstring and tutorial changes * Update docstring. fix mypy * Add latest docstring and tutorial changes Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent
c4607cbd98
commit
8adf5b4737
@@ -225,7 +225,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
|
|||||||
#### \_\_init\_\_
|
#### \_\_init\_\_
|
||||||
|
|
||||||
```python
|
```python
|
||||||
| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product", progress_bar: bool = True)
|
| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", progress_bar: bool = True)
|
||||||
```
|
```
|
||||||
|
|
||||||
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
|
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
|
||||||
@@ -264,6 +264,11 @@ titles contain meaningful information for retrieval (topic, entities etc.) .
|
|||||||
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
|
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
|
||||||
before writing them to the DocumentStore like this:
|
before writing them to the DocumentStore like this:
|
||||||
{"text": "my text", "meta": {"name": "my title"}}.
|
{"text": "my text", "meta": {"name": "my title"}}.
|
||||||
|
- `use_fast_tokenizers`: Whether to use fast Rust tokenizers
|
||||||
|
- `infer_tokenizer_classes`: Whether to infer tokenizer class from the model config / name.
|
||||||
|
If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`.
|
||||||
|
- `similarity_function`: Which function to apply for calculating the similarity of query and passage embeddings during training.
|
||||||
|
Options: `dot_product` (Default) or `cosine`
|
||||||
- `progress_bar`: Whether to show a tqdm progress bar or not.
|
- `progress_bar`: Whether to show a tqdm progress bar or not.
|
||||||
Can be helpful to disable in production deployments to keep the logs clean.
|
Can be helpful to disable in production deployments to keep the logs clean.
|
||||||
|
|
||||||
|
@@ -46,6 +46,7 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
batch_size: int = 16,
|
batch_size: int = 16,
|
||||||
embed_title: bool = True,
|
embed_title: bool = True,
|
||||||
use_fast_tokenizers: bool = True,
|
use_fast_tokenizers: bool = True,
|
||||||
|
infer_tokenizer_classes: bool = False,
|
||||||
similarity_function: str = "dot_product",
|
similarity_function: str = "dot_product",
|
||||||
progress_bar: bool = True
|
progress_bar: bool = True
|
||||||
):
|
):
|
||||||
@@ -84,6 +85,11 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
|
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
|
||||||
before writing them to the DocumentStore like this:
|
before writing them to the DocumentStore like this:
|
||||||
{"text": "my text", "meta": {"name": "my title"}}.
|
{"text": "my text", "meta": {"name": "my title"}}.
|
||||||
|
:param use_fast_tokenizers: Whether to use fast Rust tokenizers
|
||||||
|
:param infer_tokenizer_classes: Whether to infer tokenizer class from the model config / name.
|
||||||
|
If `False`, the class always loads `DPRQuestionEncoderTokenizer` and `DPRContextEncoderTokenizer`.
|
||||||
|
:param similarity_function: Which function to apply for calculating the similarity of query and passage embeddings during training.
|
||||||
|
Options: `dot_product` (Default) or `cosine`
|
||||||
:param progress_bar: Whether to show a tqdm progress bar or not.
|
:param progress_bar: Whether to show a tqdm progress bar or not.
|
||||||
Can be helpful to disable in production deployments to keep the logs clean.
|
Can be helpful to disable in production deployments to keep the logs clean.
|
||||||
"""
|
"""
|
||||||
@@ -109,13 +115,21 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
self.device = torch.device("cpu")
|
self.device = torch.device("cpu")
|
||||||
|
|
||||||
self.embed_title = embed_title
|
self.embed_title = embed_title
|
||||||
|
self.infer_tokenizer_classes = infer_tokenizer_classes
|
||||||
|
tokenizers_default_classes = {
|
||||||
|
"query": "DPRQuestionEncoderTokenizer",
|
||||||
|
"passage": "DPRContextEncoderTokenizer"
|
||||||
|
}
|
||||||
|
if self.infer_tokenizer_classes:
|
||||||
|
tokenizers_default_classes["query"] = None # type: ignore
|
||||||
|
tokenizers_default_classes["passage"] = None # type: ignore
|
||||||
|
|
||||||
# Init & Load Encoders
|
# Init & Load Encoders
|
||||||
self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
|
self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
|
||||||
revision=model_version,
|
revision=model_version,
|
||||||
do_lower_case=True,
|
do_lower_case=True,
|
||||||
use_fast=use_fast_tokenizers,
|
use_fast=use_fast_tokenizers,
|
||||||
tokenizer_class="DPRQuestionEncoderTokenizer")
|
tokenizer_class=tokenizers_default_classes["query"])
|
||||||
self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
|
self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
|
||||||
revision=model_version,
|
revision=model_version,
|
||||||
language_model_class="DPRQuestionEncoder")
|
language_model_class="DPRQuestionEncoder")
|
||||||
@@ -123,7 +137,7 @@ class DensePassageRetriever(BaseRetriever):
|
|||||||
revision=model_version,
|
revision=model_version,
|
||||||
do_lower_case=True,
|
do_lower_case=True,
|
||||||
use_fast=use_fast_tokenizers,
|
use_fast=use_fast_tokenizers,
|
||||||
tokenizer_class="DPRContextEncoderTokenizer")
|
tokenizer_class=tokenizers_default_classes["passage"])
|
||||||
self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
|
self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
|
||||||
revision=model_version,
|
revision=model_version,
|
||||||
language_model_class="DPRContextEncoder")
|
language_model_class="DPRContextEncoder")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user