private hugging face models for retrievers (#1785)
* private dpr
* Add latest docstring and tutorial changes
* added parameters to child functions
* Add latest docstring and tutorial changes
* added tableextractor
* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

parent 8aa4ca29c2
commit a8c2cdc565
@@ -160,7 +160,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore)

 #### \_\_init\_\_

 ```python
-| __init__(host: Union[str, List[str]] = "localhost", port: Union[int, List[int]] = 9200, username: str = "", password: str = "", api_key_id: Optional[str] = None, api_key: Optional[str] = None, aws4auth=None, index: str = "document", label_index: str = "label", search_fields: Union[str, list] = "content", content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", embedding_dim: int = 768, custom_mapping: Optional[dict] = None, excluded_meta_data: Optional[list] = None, analyzer: str = "standard", scheme: str = "http", ca_certs: Optional[str] = None, verify_certs: bool = True, create_index: bool = True, refresh_type: str = "wait_for", similarity="dot_product", timeout=30, return_embedding: bool = False, duplicate_documents: str = 'overwrite', index_type: str = "flat", scroll: str = "1d")
+| __init__(host: Union[str, List[str]] = "localhost", port: Union[int, List[int]] = 9200, username: str = "", password: str = "", api_key_id: Optional[str] = None, api_key: Optional[str] = None, aws4auth=None, index: str = "document", label_index: str = "label", search_fields: Union[str, list] = "content", content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", embedding_dim: int = 768, custom_mapping: Optional[dict] = None, excluded_meta_data: Optional[list] = None, analyzer: str = "standard", scheme: str = "http", ca_certs: Optional[str] = None, verify_certs: bool = True, create_index: bool = True, refresh_type: str = "wait_for", similarity="dot_product", timeout=30, return_embedding: bool = False, duplicate_documents: str = 'overwrite', index_type: str = "flat", scroll: str = "1d", skip_missing_embeddings: bool = True)
 ```

 A DocumentStore using Elasticsearch to store and query the documents for our search.
@@ -215,6 +215,10 @@ A DocumentStore using Elasticsearch to store and query the documents for our sea

 - `scroll`: Determines how long the current index is fixed, e.g. during updating all documents with embeddings.
   Defaults to "1d" and should not be larger than this. Can also be in minutes "5m" or hours "15h".
   For details, see https://www.elastic.co/guide/en/elasticsearch/reference/current/scroll-api.html
+- `skip_missing_embeddings`: Controls queries based on vector similarity when indexed documents are missing embeddings.
+  Parameter options: (True, False)
+  False: Raises an exception if one or more documents do not have embeddings at query time.
+  True: Query will ignore all documents without embeddings (recommended if you concurrently index and query).

 <a name="elasticsearch.ElasticsearchDocumentStore.get_document_by_id"></a>

 #### get\_document\_by\_id
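To make the new flag concrete, here is a minimal sketch (not part of this commit) of both settings; the host and index name are assumptions, and the import path can differ between Haystack versions:

```python
# Minimal sketch, assuming a local Elasticsearch instance and an index that
# mixes documents with and without embeddings. Import path varies by version.
from haystack.document_stores import ElasticsearchDocumentStore

# Default: embedding-based queries silently skip documents that have no
# embedding yet - safe while indexing and querying concurrently.
lenient_store = ElasticsearchDocumentStore(host="localhost", index="document",
                                           skip_missing_embeddings=True)

# Strict: an embedding query raises if any indexed document lacks an
# embedding, surfacing incomplete update_embeddings() runs early.
strict_store = ElasticsearchDocumentStore(host="localhost", index="document",
                                          skip_missing_embeddings=False)
```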
@@ -250,7 +250,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que

 #### \_\_init\_\_

 ```python
-| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None)
+| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
 ```

 Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -301,6 +301,9 @@ The checkpoint format matches huggingface transformers' model format

 Can be helpful to disable in production deployments to keep the logs clean.
 - `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
   As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `use_auth_token`: API token used to download private models from Hugging Face. If this parameter is set to `True`,
+  the locally saved token will be used, which must previously have been created via `transformers-cli login`.
+  Additional information can be found at https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

 <a name="dense.DensePassageRetriever.retrieve"></a>

 #### retrieve
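For illustration, a hedged sketch (not part of the diff) of loading private DPR checkpoints with the new parameter; the "my-org/..." model names are hypothetical private repositories and the import paths may differ across Haystack versions:

```python
# Hedged sketch: loading private DPR encoders with use_auth_token.
# Model names are hypothetical; import paths vary by Haystack version.
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever

retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="my-org/private-dpr-question_encoder",
    passage_embedding_model="my-org/private-dpr-ctx_encoder",
    use_auth_token="hf_xxx",  # placeholder token string; or True to reuse
                              # the token saved by `transformers-cli login`
)
```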
@@ -442,7 +445,7 @@ Kostić, Bogdan, et al. (2021): "Multi-modal Retrieval of Tables and Texts Using

 #### \_\_init\_\_

 ```python
-| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None)
+| __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
 ```

 Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -479,6 +482,9 @@ The checkpoint format matches huggingface transformers' model format

 Can be helpful to disable in production deployments to keep the logs clean.
 - `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
   As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `use_auth_token`: API token used to download private models from Hugging Face. If this parameter is set to `True`,
+  the locally saved token will be used, which must previously have been created via `transformers-cli login`.
+  Additional information can be found at https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

 <a name="dense.TableTextRetriever.embed_queries"></a>

 #### embed\_queries
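The same parameter applies to the multi-modal retriever. A hedged sketch (not in the diff) that reads the token from an environment variable instead of hard-coding it; all model names are hypothetical:

```python
# Hedged sketch: pass a token read from the environment. Model names are
# hypothetical private repositories; import paths vary by Haystack version.
import os

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import TableTextRetriever

retriever = TableTextRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="my-org/private-question_encoder",
    passage_embedding_model="my-org/private-passage_encoder",
    table_embedding_model="my-org/private-table_encoder",
    use_auth_token=os.environ.get("HF_API_TOKEN"),  # assumed env var name
)
```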
@@ -605,7 +611,7 @@ class EmbeddingRetriever(BaseRetriever)

 #### \_\_init\_\_

 ```python
-| __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None)
+| __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
 ```

 **Arguments**:
@@ -631,7 +637,10 @@ class EmbeddingRetriever(BaseRetriever)

 - `top_k`: How many documents to return per query.
 - `progress_bar`: If true displays progress bar during embedding.
 - `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
   As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `use_auth_token`: API token used to download private models from Hugging Face. If this parameter is set to `True`,
+  the locally saved token will be used, which must previously have been created via `transformers-cli login`.
+  Additional information can be found at https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

 <a name="dense.EmbeddingRetriever.retrieve"></a>

 #### retrieve
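A hedged sketch (not in the diff) of the `use_auth_token=True` variant, which falls back to the token stored locally by `transformers-cli login`; the model name is hypothetical:

```python
# Hedged sketch: use_auth_token=True reuses the cached CLI login token.
# Model name is hypothetical; import paths vary by Haystack version.
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(similarity="cosine"),
    embedding_model="my-org/private-sentence-encoder",
    model_format="farm",
    use_auth_token=True,
)
```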
@@ -1,4 +1,5 @@
 import logging
+from typing import Optional, Union

 from transformers import AutoModelForQuestionAnswering
@@ -42,7 +43,7 @@ class Converter:
         return converted_models

     @staticmethod
-    def convert_from_transformers(model_name_or_path, device, revision=None, task_type=None, processor=None, **kwargs):
+    def convert_from_transformers(model_name_or_path, device, revision=None, task_type=None, processor=None, use_auth_token: Optional[Union[bool, str]] = None, **kwargs):
         """
         Load a (downstream) model from huggingface's transformers format. Use cases:
         - continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)
@@ -66,7 +67,7 @@ class Converter:
         :return: AdaptiveModel
         """

-        lm = LanguageModel.load(model_name_or_path, revision=revision, **kwargs)
+        lm = LanguageModel.load(model_name_or_path, revision=revision, use_auth_token=use_auth_token, **kwargs)
         if task_type is None:
             # Infer task type from config
             architecture = lm.model.config.architectures[0]
@@ -120,6 +120,7 @@ class Inferencer:
                  tokenizer_args: Dict = None,
                  multithreading_rust: bool = True,
                  devices: Optional[List[Union[int, str, torch.device]]] = None,
+                 use_auth_token: Optional[Union[bool, str]] = None,
                  **kwargs
                  ):
         """
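To show where this lands for users of the modeling layer, a hedged sketch (not in the diff) of loading a private model through the Inferencer, which per this commit forwards `use_auth_token` to the model and processor converters; the model name is hypothetical and the module path may differ between versions:

```python
# Hedged sketch: Inferencer.load() forwarding use_auth_token downstream.
# Module path is Haystack's vendored modeling layer; it may vary by version.
from haystack.modeling.infer import Inferencer

inferencer = Inferencer.load(
    model_name_or_path="my-org/private-qa-model",  # hypothetical private checkpoint
    task_type="question_answering",
    use_auth_token=True,  # or a token string such as "hf_..."
)
```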
@@ -189,6 +190,7 @@ class Inferencer:
                 revision=revision,
                 device=devices[0],  # type: ignore
                 task_type=task_type,
+                use_auth_token=use_auth_token,
                 **kwargs)
             processor = Processor.convert_from_transformers(model_name_or_path,
                                                             revision=revision,
@@ -198,6 +200,7 @@ class Inferencer:
                 tokenizer_class=tokenizer_class,
                 tokenizer_args=tokenizer_args,
                 use_fast=use_fast,
+                use_auth_token=use_auth_token,
                 **kwargs)

             # override processor attributes loaded from config or HF with inferencer params
@@ -467,7 +467,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):

     @classmethod
     def convert_from_transformers(cls, model_name_or_path: Union[str, Path], device: str, revision: Optional[str] = None,
-                                  task_type: Optional[str] = None, processor: Optional[Processor] = None, **kwargs):
+                                  task_type: Optional[str] = None, processor: Optional[Processor] = None, use_auth_token: Optional[Union[bool, str]] = None, **kwargs):
         """
         Load a (downstream) model from huggingface's transformers format. Use cases:
         - continue training in Haystack (e.g. take a squad QA model and fine-tune on your own data)
@@ -494,6 +494,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
             device=device,
             task_type=task_type,
             processor=processor,
+            use_auth_token=use_auth_token,
             **kwargs)

@@ -1149,11 +1149,11 @@ class DPRQuestionEncoder(LanguageModel):
             dpr_question_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(str(pretrained_model_name_or_path)).model
             dpr_question_encoder.language = dpr_question_encoder.model.config.language
         else:
-            original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, use_auth_token=use_auth_token)
             if original_model_config.model_type == "dpr":
                 # "pretrained dpr model": load existing pretrained DPRQuestionEncoder model
                 dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained(
-                    str(pretrained_model_name_or_path), **kwargs)
+                    str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **kwargs)
             else:
                 # "from scratch": load weights from different architecture (e.g. bert) into DPRQuestionEncoder
                 # but keep config values from original architecture
@@ -1165,7 +1165,7 @@ class DPRQuestionEncoder(LanguageModel):
                 original_config_dict.update(kwargs)
                 dpr_question_encoder.model = transformers.DPRQuestionEncoder(config=transformers.DPRConfig(**original_config_dict))
                 dpr_question_encoder.model.base_model.bert_model = AutoModel.from_pretrained(
-                    str(pretrained_model_name_or_path), **original_config_dict)
+                    str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **original_config_dict)
             dpr_question_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)

         return dpr_question_encoder
@@ -1281,16 +1281,16 @@ class DPRContextEncoder(LanguageModel):
             dpr_context_encoder.model = transformers.DPRContextEncoder(config=transformers.DPRConfig(**original_config_dict))
             language_model_class = cls.get_language_model_class(haystack_lm_config, **kwargs)
             dpr_context_encoder.model.base_model.bert_model = cls.subclasses[language_model_class].load(
-                str(pretrained_model_name_or_path)).model
+                str(pretrained_model_name_or_path), use_auth_token=use_auth_token).model
             dpr_context_encoder.language = dpr_context_encoder.model.config.language

         else:
             # Pytorch-transformer Style
-            original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, use_auth_token=use_auth_token)
             if original_model_config.model_type == "dpr":
                 # "pretrained dpr model": load existing pretrained DPRContextEncoder model
                 dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained(
-                    str(pretrained_model_name_or_path), **kwargs)
+                    str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **kwargs)
             else:
                 # "from scratch": load weights from different architecture (e.g. bert) into DPRContextEncoder
                 # but keep config values from original architecture
@@ -1304,7 +1304,7 @@ class DPRContextEncoder(LanguageModel):
                 dpr_context_encoder.model = transformers.DPRContextEncoder(
                     config=transformers.DPRConfig(**original_config_dict))
                 dpr_context_encoder.model.base_model.bert_model = AutoModel.from_pretrained(
-                    str(pretrained_model_name_or_path), **original_config_dict)
+                    str(pretrained_model_name_or_path), use_auth_token=use_auth_token, **original_config_dict)
             dpr_context_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)

         return dpr_context_encoder
@@ -55,7 +55,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
             retriever.embedding_model, revision=retriever.model_version, task_type="embeddings",
             extraction_strategy=retriever.pooling_strategy,
             extraction_layer=retriever.emb_extraction_layer, gpu=retriever.use_gpu,
-            batch_size=4, max_seq_len=512, num_processes=0
+            batch_size=4, max_seq_len=512, num_processes=0, use_auth_token=retriever.use_auth_token
         )
         # Check that document_store has the right similarity function
         similarity = retriever.document_store.similarity
@@ -52,7 +52,8 @@ class DensePassageRetriever(BaseRetriever):
                  similarity_function: str = "dot_product",
                  global_loss_buffer_size: int = 150000,
                  progress_bar: bool = True,
-                 devices: Optional[List[Union[int, str, torch.device]]] = None
+                 devices: Optional[List[Union[int, str, torch.device]]] = None,
+                 use_auth_token: Optional[Union[str, bool]] = None,
                  ):
         """
         Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -101,6 +102,9 @@ class DensePassageRetriever(BaseRetriever):
                              Can be helpful to disable in production deployments to keep the logs clean.
         :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
                         As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+        :param use_auth_token: API token used to download private models from Hugging Face. If this parameter is set to `True`,
+                               the locally saved token will be used, which must previously have been created via `transformers-cli login`.
+                               Additional information can be found at https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
         """
         # save init parameters to enable export of component config as YAML
         self.set_config(
@@ -109,7 +113,7 @@ class DensePassageRetriever(BaseRetriever):
             model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
             top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
             use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
-            similarity_function=similarity_function, progress_bar=progress_bar, devices=devices
+            similarity_function=similarity_function, progress_bar=progress_bar, devices=devices, use_auth_token=use_auth_token
         )

         if devices is not None:
@@ -148,18 +152,22 @@ class DensePassageRetriever(BaseRetriever):
                                                   revision=model_version,
                                                   do_lower_case=True,
                                                   use_fast=use_fast_tokenizers,
-                                                  tokenizer_class=tokenizers_default_classes["query"])
+                                                  tokenizer_class=tokenizers_default_classes["query"],
+                                                  use_auth_token=use_auth_token)
         self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
                                                 revision=model_version,
-                                                language_model_class="DPRQuestionEncoder")
+                                                language_model_class="DPRQuestionEncoder",
+                                                use_auth_token=use_auth_token)
         self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
                                                 revision=model_version,
                                                 do_lower_case=True,
                                                 use_fast=use_fast_tokenizers,
-                                                tokenizer_class=tokenizers_default_classes["passage"])
+                                                tokenizer_class=tokenizers_default_classes["passage"],
+                                                use_auth_token=use_auth_token)
         self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
                                                   revision=model_version,
-                                                  language_model_class="DPRContextEncoder")
+                                                  language_model_class="DPRContextEncoder",
+                                                  use_auth_token=use_auth_token)

         self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
                                                  passage_tokenizer=self.passage_tokenizer,
@@ -487,7 +495,8 @@ class TableTextRetriever(BaseRetriever):
                  similarity_function: str = "dot_product",
                  global_loss_buffer_size: int = 150000,
                  progress_bar: bool = True,
-                 devices: Optional[List[Union[int, str, torch.device]]] = None
+                 devices: Optional[List[Union[int, str, torch.device]]] = None,
+                 use_auth_token: Optional[Union[str, bool]] = None
                  ):
         """
         Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -522,6 +531,9 @@ class TableTextRetriever(BaseRetriever):
                              Can be helpful to disable in production deployments to keep the logs clean.
         :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
                         As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+        :param use_auth_token: API token used to download private models from Hugging Face. If this parameter is set to `True`,
+                               the locally saved token will be used, which must previously have been created via `transformers-cli login`.
+                               Additional information can be found at https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
         """
         # save init parameters to enable export of component config as YAML
         self.set_config(
@@ -531,7 +543,7 @@ class TableTextRetriever(BaseRetriever):
             max_seq_len_table=max_seq_len_table, top_k=top_k, use_gpu=use_gpu, batch_size=batch_size,
             embed_meta_fields=embed_meta_fields, use_fast_tokenizers=use_fast_tokenizers,
             infer_tokenizer_classes=infer_tokenizer_classes, similarity_function=similarity_function,
-            progress_bar=progress_bar, devices=devices
+            progress_bar=progress_bar, devices=devices, use_auth_token=use_auth_token
         )

         if devices is not None:
@@ -573,26 +585,32 @@ class TableTextRetriever(BaseRetriever):
                                               revision=model_version,
                                               do_lower_case=True,
                                               use_fast=use_fast_tokenizers,
-                                              tokenizer_class=tokenizers_default_classes["query"])
+                                              tokenizer_class=tokenizers_default_classes["query"],
+                                              use_auth_token=use_auth_token)
         self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
                                                 revision=model_version,
-                                                language_model_class="DPRQuestionEncoder")
+                                                language_model_class="DPRQuestionEncoder",
+                                                use_auth_token=use_auth_token)
         self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
                                                 revision=model_version,
                                                 do_lower_case=True,
                                                 use_fast=use_fast_tokenizers,
-                                                tokenizer_class=tokenizers_default_classes["passage"])
+                                                tokenizer_class=tokenizers_default_classes["passage"],
+                                                use_auth_token=use_auth_token)
         self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
                                                   revision=model_version,
-                                                  language_model_class="DPRContextEncoder")
+                                                  language_model_class="DPRContextEncoder",
+                                                  use_auth_token=use_auth_token)
         self.table_tokenizer = Tokenizer.load(pretrained_model_name_or_path=table_embedding_model,
                                               revision=model_version,
                                               do_lower_case=True,
                                               use_fast=use_fast_tokenizers,
-                                              tokenizer_class=tokenizers_default_classes["table"])
+                                              tokenizer_class=tokenizers_default_classes["table"],
+                                              use_auth_token=use_auth_token)
         self.table_encoder = LanguageModel.load(pretrained_model_name_or_path=table_embedding_model,
                                                 revision=model_version,
-                                                language_model_class="DPRContextEncoder")
+                                                language_model_class="DPRContextEncoder",
+                                                use_auth_token=use_auth_token)

         self.processor = TableTextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
                                                       passage_tokenizer=self.passage_tokenizer,
@@ -942,7 +960,8 @@ class EmbeddingRetriever(BaseRetriever):
                  emb_extraction_layer: int = -1,
                  top_k: int = 10,
                  progress_bar: bool = True,
-                 devices: Optional[List[Union[int, str, torch.device]]] = None
+                 devices: Optional[List[Union[int, str, torch.device]]] = None,
+                 use_auth_token: Optional[Union[str, bool]] = None
                  ):
         """
         :param document_store: An instance of DocumentStore from which to retrieve documents.
@@ -966,7 +985,10 @@ class EmbeddingRetriever(BaseRetriever):
         :param top_k: How many documents to return per query.
         :param progress_bar: If true displays progress bar during embedding.
         :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
                         As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+        :param use_auth_token: API token used to download private models from Hugging Face. If this parameter is set to `True`,
+                               the locally saved token will be used, which must previously have been created via `transformers-cli login`.
+                               Additional information can be found at https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
         """
         # save init parameters to enable export of component config as YAML
         self.set_config(
@@ -989,6 +1011,7 @@ class EmbeddingRetriever(BaseRetriever):
         self.emb_extraction_layer = emb_extraction_layer
         self.top_k = top_k
         self.progress_bar = progress_bar
+        self.use_auth_token = use_auth_token

         logger.info(f"Init retriever using embeddings of model {embedding_model}")