From a8c2cdc5655bfd9f544ca44f1a439cfc2c1ff753 Mon Sep 17 00:00:00 2001
From: Kristof Herrmann <37148029+ArzelaAscoIi@users.noreply.github.com>
Date: Mon, 22 Nov 2021 09:24:02 +0100
Subject: [PATCH] private hugging face models for retrievers (#1785)
* private dpr
* Add latest docstring and tutorial changes
* added parameters to child functions
* Add latest docstring and tutorial changes
* added tableextractor
* Add latest docstring and tutorial changes
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
docs/_src/api/api/document_store.md | 6 +-
docs/_src/api/api/retriever.md | 17 ++++--
haystack/modeling/conversion/transformers.py | 5 +-
haystack/modeling/infer.py | 3 +
haystack/modeling/model/adaptive_model.py | 3 +-
haystack/modeling/model/language_model.py | 14 ++---
.../nodes/retriever/_embedding_encoder.py | 2 +-
haystack/nodes/retriever/dense.py | 55 +++++++++++++------
8 files changed, 73 insertions(+), 32 deletions(-)
diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md
index 3db306c5b..d35095d22 100644
--- a/docs/_src/api/api/document_store.md
+++ b/docs/_src/api/api/document_store.md
@@ -160,7 +160,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore)
#### \_\_init\_\_
```python
- | __init__(host: Union[str, List[str]] = "localhost", port: Union[int, List[int]] = 9200, username: str = "", password: str = "", api_key_id: Optional[str] = None, api_key: Optional[str] = None, aws4auth=None, index: str = "document", label_index: str = "label", search_fields: Union[str, list] = "content", content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", embedding_dim: int = 768, custom_mapping: Optional[dict] = None, excluded_meta_data: Optional[list] = None, analyzer: str = "standard", scheme: str = "http", ca_certs: Optional[str] = None, verify_certs: bool = True, create_index: bool = True, refresh_type: str = "wait_for", similarity="dot_product", timeout=30, return_embedding: bool = False, duplicate_documents: str = 'overwrite', index_type: str = "flat", scroll: str = "1d")
+ | __init__(host: Union[str, List[str]] = "localhost", port: Union[int, List[int]] = 9200, username: str = "", password: str = "", api_key_id: Optional[str] = None, api_key: Optional[str] = None, aws4auth=None, index: str = "document", label_index: str = "label", search_fields: Union[str, list] = "content", content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", embedding_dim: int = 768, custom_mapping: Optional[dict] = None, excluded_meta_data: Optional[list] = None, analyzer: str = "standard", scheme: str = "http", ca_certs: Optional[str] = None, verify_certs: bool = True, create_index: bool = True, refresh_type: str = "wait_for", similarity="dot_product", timeout=30, return_embedding: bool = False, duplicate_documents: str = 'overwrite', index_type: str = "flat", scroll: str = "1d", skip_missing_embeddings: bool = True)
```
A DocumentStore using Elasticsearch to store and query the documents for our search.
@@ -215,6 +215,10 @@ A DocumentStore using Elasticsearch to store and query the documents for our sea
- `scroll`: Determines how long the current index is fixed, e.g. during updating all documents with embeddings.
Defaults to "1d" and should not be larger than this. Can also be in minutes "5m" or hours "15h"
For details, see https://www.elastic.co/guide/en/elasticsearch/reference/current/scroll-api.html
+- `skip_missing_embeddings`: Parameter to control queries based on vector similarity when indexed documents miss embeddings.
+ Parameter options: (True, False)
+ False: Raises exception if one or more documents do not have embeddings at query time
+ True: Query will ignore all documents without embeddings (recommended if you concurrently index and query)
#### get\_document\_by\_id
diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index a782904c5..140728aa9 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -250,7 +250,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
#### \_\_init\_\_
```python
- | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None)
+ | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str,bool]] = None)
```
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -301,6 +301,9 @@ The checkpoint format matches huggingface transformers' model format
Can be helpful to disable in production deployments to keep the logs clean.
- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
+ the local token will be used, which must be previously created via `transformer-cli login`.
+ Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
#### retrieve
@@ -442,7 +445,7 @@ Kostić, Bogdan, et al. (2021): "Multi-modal Retrieval of Tables and Texts Using
#### \_\_init\_\_
```python
- | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None)
+ | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str,bool]] = None)
```
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -479,6 +482,9 @@ The checkpoint format matches huggingface transformers' model format
Can be helpful to disable in production deployments to keep the logs clean.
- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
+ the local token will be used, which must be previously created via `transformer-cli login`.
+ Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
#### embed\_queries
@@ -605,7 +611,7 @@ class EmbeddingRetriever(BaseRetriever)
#### \_\_init\_\_
```python
- | __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None)
+ | __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str,bool]] = None)
```
**Arguments**:
@@ -631,7 +637,10 @@ class EmbeddingRetriever(BaseRetriever)
- `top_k`: How many documents to return per query.
- `progress_bar`: If true displays progress bar during embedding.
- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
- As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+ As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
+ the local token will be used, which must be previously created via `transformer-cli login`.
+ Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
#### retrieve
diff --git a/haystack/modeling/conversion/transformers.py b/haystack/modeling/conversion/transformers.py
index 1827b4983..ac995671d 100644
--- a/haystack/modeling/conversion/transformers.py
+++ b/haystack/modeling/conversion/transformers.py
@@ -1,4 +1,5 @@
import logging
+from typing import Union
from transformers import AutoModelForQuestionAnswering
@@ -42,7 +43,7 @@ class Converter:
return converted_models
@staticmethod
- def convert_from_transformers(model_name_or_path, device, revision=None, task_type=None, processor=None, **kwargs):
+ def convert_from_transformers(model_name_or_path, device, revision=None, task_type=None, processor=None, use_auth_token: Union[bool, str] = None, **kwargs):
"""
Load a (downstream) model from huggingface's transformers format. Use cases:
- continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)
@@ -66,7 +67,7 @@ class Converter:
:return: AdaptiveModel
"""
- lm = LanguageModel.load(model_name_or_path, revision=revision, **kwargs)
+ lm = LanguageModel.load(model_name_or_path, revision=revision,use_auth_token=use_auth_token, **kwargs)
if task_type is None:
# Infer task type from config
architecture = lm.model.config.architectures[0]
diff --git a/haystack/modeling/infer.py b/haystack/modeling/infer.py
index cd105f3ec..b4b0fd719 100644
--- a/haystack/modeling/infer.py
+++ b/haystack/modeling/infer.py
@@ -120,6 +120,7 @@ class Inferencer:
tokenizer_args: Dict =None,
multithreading_rust: bool = True,
devices: Optional[List[Union[int, str, torch.device]]] = None,
+ use_auth_token: Union[bool, str] = None,
**kwargs
):
"""
@@ -189,6 +190,7 @@ class Inferencer:
revision=revision,
device=devices[0], # type: ignore
task_type=task_type,
+ use_auth_token=use_auth_token,
**kwargs)
processor = Processor.convert_from_transformers(model_name_or_path,
revision=revision,
@@ -198,6 +200,7 @@ class Inferencer:
tokenizer_class=tokenizer_class,
tokenizer_args=tokenizer_args,
use_fast=use_fast,
+ use_auth_token=use_auth_token,
**kwargs)
# override processor attributes loaded from config or HF with inferencer params
diff --git a/haystack/modeling/model/adaptive_model.py b/haystack/modeling/model/adaptive_model.py
index 089e3e740..42e892f41 100644
--- a/haystack/modeling/model/adaptive_model.py
+++ b/haystack/modeling/model/adaptive_model.py
@@ -467,7 +467,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
@classmethod
def convert_from_transformers(cls, model_name_or_path: Union[str, Path], device: str, revision: Optional[str] = None,
- task_type: Optional[str] = None, processor: Optional[Processor] = None, **kwargs):
+ task_type: Optional[str] = None, processor: Optional[Processor] = None, use_auth_token: Optional[Union[bool, str]] = None, **kwargs):
"""
Load a (downstream) model from huggingface's transformers format. Use cases:
- continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)
@@ -494,6 +494,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
device=device,
task_type=task_type,
processor=processor,
+ use_auth_token=use_auth_token,
**kwargs)
diff --git a/haystack/modeling/model/language_model.py b/haystack/modeling/model/language_model.py
index 2eccd7c19..806b04594 100644
--- a/haystack/modeling/model/language_model.py
+++ b/haystack/modeling/model/language_model.py
@@ -1149,11 +1149,11 @@ class DPRQuestionEncoder(LanguageModel):
dpr_question_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path)).model
dpr_question_encoder.language = dpr_question_encoder.model.config.language
else:
- original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+ original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, use_auth_token=use_auth_token)
if original_model_config.model_type == "dpr":
# "pretrained dpr model": load existing pretrained DPRQuestionEncoder model
dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained(
- str(pretrained_model_name_or_path), **kwargs)
+ str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **kwargs)
else:
# "from scratch": load weights from different architecture (e.g. bert) into DPRQuestionEncoder
# but keep config values from original architecture
@@ -1165,7 +1165,7 @@ class DPRQuestionEncoder(LanguageModel):
original_config_dict.update(kwargs)
dpr_question_encoder.model = transformers.DPRQuestionEncoder(config=transformers.DPRConfig(**original_config_dict))
dpr_question_encoder.model.base_model.bert_model = AutoModel.from_pretrained(
- str(pretrained_model_name_or_path), **original_config_dict)
+ str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **original_config_dict)
dpr_question_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
return dpr_question_encoder
@@ -1281,16 +1281,16 @@ class DPRContextEncoder(LanguageModel):
dpr_context_encoder.model = transformers.DPRContextEncoder(config=transformers.DPRConfig(**original_config_dict))
language_model_class = cls.get_language_model_class(haystack_lm_config, **kwargs)
dpr_context_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(
- str(pretrained_model_name_or_path)).model
+ str(pretrained_model_name_or_path), use_auth_token=use_auth_token).model
dpr_context_encoder.language = dpr_context_encoder.model.config.language
else:
# Pytorch-transformer Style
- original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+ original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, use_auth_token=use_auth_token)
if original_model_config.model_type == "dpr":
# "pretrained dpr model": load existing pretrained DPRContextEncoder model
dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained(
- str(pretrained_model_name_or_path), **kwargs)
+ str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **kwargs)
else:
# "from scratch": load weights from different architecture (e.g. bert) into DPRContextEncoder
# but keep config values from original architecture
@@ -1304,7 +1304,7 @@ class DPRContextEncoder(LanguageModel):
dpr_context_encoder.model = transformers.DPRContextEncoder(
config=transformers.DPRConfig(**original_config_dict))
dpr_context_encoder.model.base_model.bert_model = AutoModel.from_pretrained(
- str(pretrained_model_name_or_path), **original_config_dict)
+ str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **original_config_dict)
dpr_context_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
return dpr_context_encoder
diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py
index 37c94dac8..da4617ff8 100644
--- a/haystack/nodes/retriever/_embedding_encoder.py
+++ b/haystack/nodes/retriever/_embedding_encoder.py
@@ -55,7 +55,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
retriever.embedding_model, revision=retriever.model_version, task_type="embeddings",
extraction_strategy=retriever.pooling_strategy,
extraction_layer=retriever.emb_extraction_layer, gpu=retriever.use_gpu,
- batch_size=4, max_seq_len=512, num_processes=0
+ batch_size=4, max_seq_len=512, num_processes=0,use_auth_token=retriever.use_auth_token
)
# Check that document_store has the right similarity function
similarity = retriever.document_store.similarity
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index febda90fd..c5635eff6 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -52,7 +52,8 @@ class DensePassageRetriever(BaseRetriever):
similarity_function: str = "dot_product",
global_loss_buffer_size: int = 150000,
progress_bar: bool = True,
- devices: Optional[List[Union[int, str, torch.device]]] = None
+ devices: Optional[List[Union[int, str, torch.device]]] = None,
+ use_auth_token: Optional[Union[str,bool]] = None,
):
"""
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -101,6 +102,9 @@ class DensePassageRetriever(BaseRetriever):
Can be helpful to disable in production deployments to keep the logs clean.
:param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+ :param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
+ the local token will be used, which must be previously created via `transformer-cli login`.
+ Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
"""
# save init parameters to enable export of component config as YAML
self.set_config(
@@ -109,7 +113,7 @@ class DensePassageRetriever(BaseRetriever):
model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
- similarity_function=similarity_function, progress_bar=progress_bar, devices=devices
+ similarity_function=similarity_function, progress_bar=progress_bar, devices=devices, use_auth_token=use_auth_token
)
if devices is not None:
@@ -148,18 +152,22 @@ class DensePassageRetriever(BaseRetriever):
revision=model_version,
do_lower_case=True,
use_fast=use_fast_tokenizers,
- tokenizer_class=tokenizers_default_classes["query"])
+ tokenizer_class=tokenizers_default_classes["query"],
+ use_auth_token=use_auth_token)
self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
revision=model_version,
- language_model_class="DPRQuestionEncoder")
+ language_model_class="DPRQuestionEncoder",
+ use_auth_token=use_auth_token)
self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
revision=model_version,
do_lower_case=True,
use_fast=use_fast_tokenizers,
- tokenizer_class=tokenizers_default_classes["passage"])
+ tokenizer_class=tokenizers_default_classes["passage"],
+ use_auth_token=use_auth_token)
self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
revision=model_version,
- language_model_class="DPRContextEncoder")
+ language_model_class="DPRContextEncoder",
+ use_auth_token=use_auth_token)
self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
passage_tokenizer=self.passage_tokenizer,
@@ -487,7 +495,8 @@ class TableTextRetriever(BaseRetriever):
similarity_function: str = "dot_product",
global_loss_buffer_size: int = 150000,
progress_bar: bool = True,
- devices: Optional[List[Union[int, str, torch.device]]] = None
+ devices: Optional[List[Union[int, str, torch.device]]] = None,
+ use_auth_token: Optional[Union[str,bool]] = None
):
"""
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -522,6 +531,9 @@ class TableTextRetriever(BaseRetriever):
Can be helpful to disable in production deployments to keep the logs clean.
:param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+ :param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
+ the local token will be used, which must be previously created via `transformer-cli login`.
+ Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
"""
# save init parameters to enable export of component config as YAML
self.set_config(
@@ -531,7 +543,7 @@ class TableTextRetriever(BaseRetriever):
max_seq_len_table=max_seq_len_table, top_k=top_k, use_gpu=use_gpu, batch_size=batch_size,
embed_meta_fields=embed_meta_fields, use_fast_tokenizers=use_fast_tokenizers,
infer_tokenizer_classes=infer_tokenizer_classes, similarity_function=similarity_function,
- progress_bar=progress_bar, devices=devices
+ progress_bar=progress_bar, devices=devices, use_auth_token=use_auth_token
)
if devices is not None:
@@ -573,26 +585,32 @@ class TableTextRetriever(BaseRetriever):
revision=model_version,
do_lower_case=True,
use_fast=use_fast_tokenizers,
- tokenizer_class=tokenizers_default_classes["query"])
+ tokenizer_class=tokenizers_default_classes["query"],
+ use_auth_token=use_auth_token)
self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
revision=model_version,
- language_model_class="DPRQuestionEncoder")
+ language_model_class="DPRQuestionEncoder",
+ use_auth_token=use_auth_token)
self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
revision=model_version,
do_lower_case=True,
use_fast=use_fast_tokenizers,
- tokenizer_class=tokenizers_default_classes["passage"])
+ tokenizer_class=tokenizers_default_classes["passage"],
+ use_auth_token=use_auth_token)
self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
revision=model_version,
- language_model_class="DPRContextEncoder")
+ language_model_class="DPRContextEncoder",
+ use_auth_token=use_auth_token)
self.table_tokenizer = Tokenizer.load(pretrained_model_name_or_path=table_embedding_model,
revision=model_version,
do_lower_case=True,
use_fast=use_fast_tokenizers,
- tokenizer_class=tokenizers_default_classes["table"])
+ tokenizer_class=tokenizers_default_classes["table"],
+ use_auth_token=use_auth_token)
self.table_encoder = LanguageModel.load(pretrained_model_name_or_path=table_embedding_model,
revision=model_version,
- language_model_class="DPRContextEncoder")
+ language_model_class="DPRContextEncoder",
+ use_auth_token=use_auth_token)
self.processor = TableTextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
passage_tokenizer=self.passage_tokenizer,
@@ -942,7 +960,8 @@ class EmbeddingRetriever(BaseRetriever):
emb_extraction_layer: int = -1,
top_k: int = 10,
progress_bar: bool = True,
- devices: Optional[List[Union[int, str, torch.device]]] = None
+ devices: Optional[List[Union[int, str, torch.device]]] = None,
+ use_auth_token: Optional[Union[str,bool]] = None
):
"""
:param document_store: An instance of DocumentStore from which to retrieve documents.
@@ -966,7 +985,10 @@ class EmbeddingRetriever(BaseRetriever):
:param top_k: How many documents to return per query.
:param progress_bar: If true displays progress bar during embedding.
:param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
- As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+ As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+ :param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
+ the local token will be used, which must be previously created via `transformer-cli login`.
+ Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
"""
# save init parameters to enable export of component config as YAML
self.set_config(
@@ -989,6 +1011,7 @@ class EmbeddingRetriever(BaseRetriever):
self.emb_extraction_layer = emb_extraction_layer
self.top_k = top_k
self.progress_bar = progress_bar
+ self.use_auth_token = use_auth_token
logger.info(f"Init retriever using embeddings of model {embedding_model}")