Mirror of https://github.com/deepset-ai/haystack.git
Set provider parameter when instantiating onnxruntime.InferenceSession (#1976)
* Set provider parameter when instantiating onnxruntime.InferenceSession (fixes #1973)
* Change device type to torch.device
* Set type annotation of device to torch.device everywhere
* Apply Black
* Change types of device and devices params across the codebase
* Update Documentation & Code Style
* Add type: ignore in the right location
* Add type: ignore
* Incorporate feedback
* Incorporate feedback, round 2
* Fix convert_to_transformers
* Fix syntax error
* Consider augment and load_glove user-facing as well
* Fix mypy
* Update Documentation & Code Style

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
This commit is contained in:
parent 851fe1cf07
commit 3b2001e66f
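The thread running through the whole diff: device parameters move from raw strings (and stray `int` entries) to `torch.device` objects, with strings still accepted and normalized on entry. A minimal standalone sketch of that notation, plain PyTorch only:

```python
import torch

# torch.device accepts the string notation used throughout this diff
# ("cpu", "cuda", "cuda:0") and passes existing device objects through
# unchanged, which is what makes Union[str, torch.device] easy to normalize.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
normalized = torch.device(device)  # no-op if device is already a torch.device

print(normalized.type)   # "cuda" or "cpu"
print(normalized.index)  # 0 on the CUDA branch, None for plain "cpu"
```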
@@ -92,7 +92,7 @@ p.add_node(component=ranker, name="Ranker", inputs=["ESRetriever"])
 #### \_\_init\_\_
 
 ```python
-def __init__(model_name_or_path: Union[str, Path], model_version: Optional[str] = None, top_k: int = 10, use_gpu: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None)
+def __init__(model_name_or_path: Union[str, Path], model_version: Optional[str] = None, top_k: int = 10, use_gpu: bool = True, devices: Optional[List[Union[str, torch.device]]] = None)
 ```
 
 **Arguments**:

@@ -103,7 +103,10 @@ See https://huggingface.co/cross-encoder for full list of available models
 - `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
 - `top_k`: The maximum number of documents to return
 - `use_gpu`: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
-- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
+- `devices`: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones
+The strings will be converted into pytorch devices, so use the string notation described here:
+https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
+(e.g. ["cuda:0"]).
 
 <a id="sentence_transformers.SentenceTransformersRanker.predict_batch"></a>
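As a usage sketch of the updated `devices` parameter (model name and the v1.x import path are illustrative assumptions):

```python
import torch
from haystack.nodes import SentenceTransformersRanker  # assumed v1.x import path

# Strings and torch.device objects can be mixed; plain ints are no longer accepted.
ranker = SentenceTransformersRanker(
    model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2",
    top_k=5,
    devices=["cuda:0", torch.device("cuda:1")],  # normalized internally with torch.device(...)
)
```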
@@ -398,7 +398,7 @@ Dict containing query and answers
 #### eval\_on\_file
 
 ```python
-def eval_on_file(data_dir: str, test_filename: str, device: Optional[str] = None)
+def eval_on_file(data_dir: Union[Path, str], test_filename: str, device: Optional[Union[str, torch.device]] = None)
 ```
 
 Performs evaluation on a SQuAD-formatted file.

@@ -410,16 +410,18 @@ Returns a dict containing the following metrics:
 
 **Arguments**:
 
-- `data_dir` (`Path or str`): The directory in which the test set can be found
-- `test_filename` (`str`): The name of the file containing the test data in SQuAD format.
-- `device` (`str`): The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
+- `data_dir`: The directory in which the test set can be found
+- `test_filename`: The name of the file containing the test data in SQuAD format.
+- `device`: The device on which the tensors should be processed.
+Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
+or use the Reader's device by default.
 
 <a id="farm.FARMReader.eval"></a>
 
 #### eval
 
 ```python
-def eval(document_store: BaseDocumentStore, device: Optional[str] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold-label", calibrate_conf_scores: bool = False)
+def eval(document_store: BaseDocumentStore, device: Optional[Union[str, torch.device]] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold-label", calibrate_conf_scores: bool = False)
 ```
 
 Performs evaluation on evaluation documents in the DocumentStore.

@@ -432,7 +434,9 @@ Returns a dict containing the following metrics:
 **Arguments**:
 
 - `document_store`: DocumentStore containing the evaluation documents
-- `device`: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
+- `device`: The device on which the tensors should be processed.
+Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
+or use the Reader's device by default.
 - `label_index`: Index/Table name where labeled questions are stored
 - `doc_index`: Index/Table name where documents that are used for evaluation are stored
 - `label_origin`: Field name where the gold labels are stored

@@ -443,7 +447,7 @@ Returns a dict containing the following metrics:
 #### calibrate\_confidence\_scores
 
 ```python
-def calibrate_confidence_scores(document_store: BaseDocumentStore, device: Optional[str] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label")
+def calibrate_confidence_scores(document_store: BaseDocumentStore, device: Optional[Union[str, torch.device]] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label")
 ```
 
 Calibrates confidence scores on evaluation documents in the DocumentStore.

@@ -451,7 +455,9 @@ Calibrates confidence scores on evaluation documents in the DocumentStore.
 **Arguments**:
 
 - `document_store`: DocumentStore containing the evaluation documents
-- `device`: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
+- `device`: The device on which the tensors should be processed.
+Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
+or use the Reader's device by default.
 - `label_index`: Index/Table name where labeled questions are stored
 - `doc_index`: Index/Table name where documents that are used for evaluation are stored
 - `label_origin`: Field name where the gold labels are stored
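A call sketch for the widened `device` argument (file paths and model name are placeholders; the import path assumes the v1.x layout):

```python
import torch
from haystack.nodes import FARMReader  # assumed v1.x import path

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

# device accepts a string or a torch.device; None falls back to the Reader's own device.
metrics = reader.eval_on_file(
    data_dir="data/squad20",        # placeholder directory
    test_filename="dev-v2.0.json",  # placeholder SQuAD file
    device=torch.device("cuda"),
)
print(metrics)
```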
@@ -312,7 +312,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
 #### \_\_init\_\_
 
 ```python
-def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
+def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
 ```
 
 Init the Retriever incl. the two encoder models from a local or remote model checkpoint.

@@ -362,8 +362,11 @@ Options: `dot_product` (Default) or `cosine`
 Increase if errors like "encoded data exceeds max_size ..." come up
 - `progress_bar`: Whether to show a tqdm progress bar or not.
 Can be helpful to disable in production deployments to keep the logs clean.
-- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
-As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `devices`: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones
+These strings will be converted into pytorch devices, so use the string notation described here:
+https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
+(e.g. ["cuda:0"]). Note: as multi-GPU training is currently not implemented for DPR, training
+will only use the first device provided in this list.
 - `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
 the local token will be used, which must be previously created via `transformer-cli login`.
 Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
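A construction sketch under the new `devices` type (import paths assumed from the v1.x layout):

```python
import torch
from haystack.document_stores import InMemoryDocumentStore  # assumed v1.x import path
from haystack.nodes import DensePassageRetriever

retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    devices=[torch.device("cuda:0")],  # or ["cuda:0"]; int entries are no longer valid
)
```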
@@ -520,7 +523,7 @@ Kostić, Bogdan, et al. (2021): "Multi-modal Retrieval of Tables and Texts Using
 #### \_\_init\_\_
 
 ```python
-def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
+def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
 ```
 
 Init the Retriever incl. the two encoder models from a local or remote model checkpoint.

@@ -556,8 +559,11 @@ Options: `dot_product` (Default) or `cosine`
 Increase if errors like "encoded data exceeds max_size ..." come up
 - `progress_bar`: Whether to show a tqdm progress bar or not.
 Can be helpful to disable in production deployments to keep the logs clean.
-- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
-As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `devices`: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones
+These strings will be converted into pytorch devices, so use the string notation described here:
+https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
+(e.g. ["cuda:0"]). Note: as multi-GPU training is currently not implemented for TableTextRetriever,
+training will only use the first device provided in this list.
 - `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
 the local token will be used, which must be previously created via `transformer-cli login`.
 Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

@@ -695,7 +701,7 @@ class EmbeddingRetriever(BaseRetriever)
 #### \_\_init\_\_
 
 ```python
-def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
+def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
 ```
 
 **Arguments**:

@@ -721,8 +727,11 @@ Options:
 Default: -1 (very last layer).
 - `top_k`: How many documents to return per query.
 - `progress_bar`: If true displays progress bar during embedding.
-- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
-As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+- `devices`: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones
+These strings will be converted into pytorch devices, so use the string notation described here:
+https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
+(e.g. ["cuda:0"]). Note: As multi-GPU training is currently not implemented for EmbeddingRetriever,
+training will only use the first device provided in this list.
 - `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
 the local token will be used, which must be previously created via `transformer-cli login`.
 Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
@@ -1,11 +1,13 @@
 import logging
+from typing import Union
 
 import torch
 from transformers import AutoModelForQuestionAnswering
 
 from haystack.modeling.model import adaptive_model as am
 from haystack.modeling.model.language_model import LanguageModel
 from haystack.modeling.model.prediction_head import QuestionAnsweringHead
+from haystack.modeling.data_handler.processor import Processor
 
 
 logger = logging.getLogger(__name__)

@@ -46,10 +48,10 @@ class Converter:
     @staticmethod
     def convert_from_transformers(
         model_name_or_path,
-        device,
-        revision=None,
-        task_type=None,
-        processor=None,
+        device: Union[str, torch.device],
+        revision: str = None,
+        task_type: str = "question_answering",
+        processor: Processor = None,
+        use_auth_token: Union[bool, str] = None,
         **kwargs,
     ):

@@ -65,14 +67,10 @@ class Converter:
                                   - deepset/bert-large-uncased-whole-word-masking-squad2
 
                                   See https://huggingface.co/models for full list
-        :param device: "cpu" or "cuda"
+        :param device: torch.device("cpu") or torch.device("cuda")
         :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
-        :type revision: str
         :param task_type: One of :
                           - 'question_answering'
-                          More tasks coming soon ...
-        :param processor: populates prediction head with information coming from tasks
-        :type processor: Processor
+                          Right now accepts only 'question_answering'.
+        :param processor: populates prediction head with information coming from tasks.
         :return: AdaptiveModel
         """
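A hedged usage sketch of the updated converter (the module path is an assumption inferred from the `conv` alias used elsewhere in this commit):

```python
import torch
from haystack.modeling.conversion.transformers import Converter  # assumed module path

# device may be a string or a torch.device; task_type now defaults to "question_answering".
model = Converter.convert_from_transformers(
    "deepset/bert-large-uncased-whole-word-masking-squad2",
    device=torch.device("cpu"),
)
```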
@@ -785,7 +785,7 @@ class DistillationDataSilo(DataSilo):
         self,
         teacher_model: "FARMReader",
         teacher_batch_size: int,
-        device: str,
+        device: torch.device,
         processor: Processor,
         batch_size: int,
         eval_batch_size: Optional[int] = None,

@@ -20,11 +20,11 @@ class Evaluator:
     Handles evaluation of a given model over a specified dataset.
     """
 
-    def __init__(self, data_loader: torch.utils.data.DataLoader, tasks, device: str, report: bool = True):
+    def __init__(self, data_loader: torch.utils.data.DataLoader, tasks, device: torch.device, report: bool = True):
         """
         :param data_loader: The PyTorch DataLoader that will return batches of data from the evaluation dataset
         :param tasks:
-        :param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda".
+        :param device: The device on which the tensors should be processed. Choose from torch.device("cpu") and torch.device("cuda").
         :param report: Whether an eval report should be generated (e.g. classification report per class).
         """
         self.data_loader = data_loader

@@ -128,7 +128,7 @@ class Inferencer:
         use_fast: bool = True,
         tokenizer_args: Dict = None,
         multithreading_rust: bool = True,
-        devices: Optional[List[Union[int, str, torch.device]]] = None,
+        devices: Optional[List[torch.device]] = None,
         use_auth_token: Union[bool, str] = None,
         **kwargs,
     ):
@@ -169,7 +169,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
         prediction_heads: List[PredictionHead],
         embeds_dropout_prob: float,
         lm_output_types: Union[str, List[str]],
-        device: str,
+        device: torch.device,
         loss_aggregation_fn: Optional[Callable] = None,
     ):
         """

@@ -182,7 +182,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
                                 "per_sequence", a single embedding will be extracted to represent the full
                                 input sequence. Can either be a single string, or a list of strings,
                                 one for each prediction head.
-        :param device: The device on which this model will operate. Either "cpu" or "cuda".
+        :param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
         :param loss_aggregation_fn: Function to aggregate the loss of multiple prediction heads.
                                     Input: loss_per_head (list of tensors), global_step (int), batch (dict)
                                     Output: aggregated loss (tensor)

@@ -258,13 +258,13 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
     # Need to save config and pipeline
 
     @classmethod
-    def load(  # type: ignore
+    def load(  # type: ignore
         cls,
-        load_dir: Union[str, Path],  # type: ignore
-        device: str,  # type: ignore
-        strict: bool = True,  # type: ignore
-        lm_name: Optional[str] = None,  # type: ignore
-        processor: Optional[Processor] = None,  # type: ignore
+        load_dir: Union[str, Path],
+        device: Union[str, torch.device],
+        strict: bool = True,
+        lm_name: Optional[str] = None,
+        processor: Optional[Processor] = None,
     ):
         """
         Loads an AdaptiveModel from a directory. The directory must contain:

@@ -277,12 +277,13 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
         * vocab.txt vocab file for language model, turning text to Wordpiece Tokens
 
         :param load_dir: Location where the AdaptiveModel is stored.
-        :param device: To which device we want to sent the model, either cpu or cuda.
+        :param device: To which device we want to send the model, either torch.device("cpu") or torch.device("cuda").
         :param lm_name: The name to assign to the loaded language model.
         :param strict: Whether to strictly enforce that the keys loaded from saved model match the ones in
                        the PredictionHead (see torch.nn.module.load_state_dict()).
         :param processor: Processor to populate prediction head with information coming from tasks.
         """
+        device = torch.device(device)
         # Language Model
         if lm_name:
             language_model = LanguageModel.load(load_dir, haystack_lm_name=lm_name)

@@ -489,9 +490,9 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
     def convert_from_transformers(
         cls,
         model_name_or_path: Union[str, Path],
-        device: str,
+        device: torch.device,
         revision: Optional[str] = None,
-        task_type: Optional[str] = None,
+        task_type: str = "question_answering",
         processor: Optional[Processor] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         **kwargs,

@@ -509,12 +510,9 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
 
         See https://huggingface.co/models for full list
         :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
-        :param device: "cpu" or "cuda"
-        :param task_type: One of :
-                          - 'question_answering'
-                          More tasks coming soon ...
+        :param device: On which hardware the conversion should take place. Choose from torch.device("cpu") or torch.device("cuda")
+        :param task_type: 'question_answering'. More tasks coming soon ...
         :param processor: Processor to populate prediction head with information coming from tasks.
-        :type processor: Processor
         :return: AdaptiveModel
         """
         return conv.Converter.convert_from_transformers(

@@ -570,7 +568,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
             tokenizer_name_or_path=model_name, task_type=task_type, max_seq_len=256, doc_stride=128, use_fast=True
         )
         processor.save(output_path)
-        model = AdaptiveModel.convert_from_transformers(model_name, device="cpu", task_type=task_type)
+        model = AdaptiveModel.convert_from_transformers(model_name, device=torch.device("cpu"), task_type=task_type)
         model.save(output_path)
         os.remove(output_path / "language_model.bin")  # remove the actual PyTorch model(only configs are required)
 

@@ -617,14 +615,14 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
         language_model_class: str,
         language: str,
         prediction_heads: List[PredictionHead],
-        device: str,
+        device: torch.device,
     ):
         """
         :param onnx_session: ? # TODO
         :param language_model_class: Class of LanguageModel
         :param language: Language the model is trained for.
         :param prediction_heads: A list of models that take embeddings and return logits for a given task.
-        :param device: The device on which this model will operate. Either "cpu" or "cuda".
+        :param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
         """
         import onnxruntime
 

@@ -642,13 +640,14 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
         self.device = device
 
     @classmethod
-    def load(cls, load_dir: Union[str, Path], device: str, **kwargs):  # type: ignore
+    def load(cls, load_dir: Union[str, Path], device: Union[str, torch.device], **kwargs):  # type: ignore
         """
         Loads an ONNXAdaptiveModel from a directory.
 
         :param load_dir: Location where the ONNXAdaptiveModel is stored.
-        :param device: The device on which this model will operate. Either "cpu" or "cuda".
+        :param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
         """
+        device = torch.device(device)
         load_dir = Path(load_dir)
         import onnxruntime
 

@@ -657,7 +656,11 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
         sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
         # Use OpenMP optimizations. Only useful for CPU, has little impact for GPUs.
         sess_options.intra_op_num_threads = multiprocessing.cpu_count()
-        onnx_session = onnxruntime.InferenceSession(str(load_dir / "model.onnx"), sess_options)
+
+        providers = kwargs.get(
+            "providers", ["CPUExecutionProvider"] if device.type == "cpu" else ["CUDAExecutionProvider"]
+        )
+        onnx_session = onnxruntime.InferenceSession(str(load_dir / "model.onnx"), sess_options, providers=providers)
 
         # Prediction heads
         _, ph_config_files = cls._get_prediction_head_files(load_dir, strict=False)
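This hunk is the fix the commit title names: the execution providers are now passed explicitly and derived from `device.type`. The same pattern standalone (model file name hypothetical):

```python
import torch
import onnxruntime

device = torch.device("cpu")  # or torch.device("cuda")

# CPU device -> CPUExecutionProvider, anything else -> CUDAExecutionProvider,
# mirroring the default chosen in the hunk above.
providers = ["CPUExecutionProvider"] if device.type == "cpu" else ["CUDAExecutionProvider"]

session = onnxruntime.InferenceSession("model.onnx", providers=providers)  # hypothetical model file
print(session.get_providers())
```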
@@ -36,7 +36,7 @@ class BiAdaptiveModel(nn.Module):
         language_model2: LanguageModel,
         prediction_heads: List[PredictionHead],
         embeds_dropout_prob: float = 0.1,
-        device: str = "cuda",
+        device: torch.device = torch.device("cuda"),
         lm1_output_types: Union[str, List[str]] = ["per_sequence"],
         lm2_output_types: Union[str, List[str]] = ["per_sequence"],
         loss_aggregation_fn: Optional[Callable] = None,

@@ -57,7 +57,7 @@ class BiAdaptiveModel(nn.Module):
                                 "per_sequence", a single embedding will be extracted to represent the full
                                 input sequence. Can either be a single string, or a list of strings,
                                 one for each prediction head.
-        :param device: The device on which this model will operate. Either "cpu" or "cuda".
+        :param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
         :param loss_aggregation_fn: Function to aggregate the loss of multiple prediction heads.
                                     Input: loss_per_head (list of tensors), global_step (int), batch (dict)
                                     Output: aggregated loss (tensor)

@@ -108,7 +108,7 @@ class BiAdaptiveModel(nn.Module):
     def load(
         cls,
         load_dir: Path,
-        device: str,
+        device: torch.device,
         strict: bool = False,
         lm1_name: str = "lm1",
         lm2_name: str = "lm2",

@@ -130,7 +130,7 @@ class BiAdaptiveModel(nn.Module):
         * special_tokens_map.json
 
         :param load_dir: Location where adaptive model is stored.
-        :param device: To which device we want to sent the model, either cpu or cuda.
+        :param device: To which device we want to send the model, either torch.device("cpu") or torch.device("cuda").
         :param lm1_name: The name to assign to the first loaded language model (for encoding queries).
         :param lm2_name: The name to assign to the second loaded language model (for encoding context/passages).
         :param strict: Whether to strictly enforce that the keys loaded from saved model match the ones in

@@ -432,8 +432,8 @@ class BiAdaptiveModel(nn.Module):
         cls,
         model_name_or_path1: Union[str, Path],
         model_name_or_path2: Union[str, Path],
-        device: str,
-        task_type: str,
+        device: torch.device,
+        task_type: str = "text_similarity",
         processor: Optional[Processor] = None,
         similarity_function: str = "dot_product",
     ):

@@ -451,9 +451,8 @@ class BiAdaptiveModel(nn.Module):
                                  Exemplary public names:
                                  - facebook/dpr-ctx_encoder-single-nq-base
                                  - deepset/bert-large-uncased-whole-word-masking-squad2
-        :param device: "cpu" or "cuda"
-        :param task_type: 'text_similarity'
-                          More tasks coming soon ...
+        :param device: On which hardware the conversion is going to run. Either torch.device("cpu") or torch.device("cuda")
+        :param task_type: 'text_similarity' More tasks coming soon ...
         :param processor: populates prediction head with information coming from tasks
         :type processor: Processor
         :return: AdaptiveModel
@@ -1,5 +1,5 @@
 # TODO analyse if this optimization is needed or whether we can use HF transformers code
-from typing import Dict, Any
+from typing import Dict, Any, Optional
 
 import inspect
 import logging

@@ -73,7 +73,7 @@ def initialize_optimizer(
     model: AdaptiveModel,
     n_batches: int,
     n_epochs: int,
-    device,
+    device: torch.device,
     learning_rate: float,
     optimizer_opts: Dict[Any, Any] = None,
     schedule_opts: Dict[Any, Any] = None,

@@ -90,7 +90,7 @@ def initialize_optimizer(
     :param model: model to optimize (e.g. trimming weights to fp16 / mixed precision)
     :param n_batches: number of batches for training
     :param n_epochs: number of epochs for training
-    :param device:
+    :param device: Which hardware will be used by the optimizer. Either torch.device("cpu") or torch.device("cuda").
     :param learning_rate: Learning rate
     :param optimizer_opts: Dict to customize the optimizer. Choose any optimizer available from torch.optim, apex.optimizers or
                            transformers.optimization by supplying the class name and the parameters for the constructor.

@@ -295,14 +295,20 @@ def get_scheduler(optimizer, opts):
     return scheduler
 
 
-def optimize_model(model, device, local_rank, optimizer=None, distributed=False, use_amp=None):
+def optimize_model(
+    model: "AdaptiveModel",
+    device: torch.device,
+    local_rank: int,
+    optimizer=None,
+    distributed: Optional[bool] = False,
+    use_amp: Optional[str] = None,
+):
     """
     Wraps MultiGPU or distributed usage around a model
     No support for ONNX models
 
     :param model: model to optimize (e.g. trimming weights to fp16 / mixed precision)
-    :type model: AdaptiveModel
-    :param device: either gpu or cpu, get the device from initialize_device_settings()
+    :param device: either torch.device("cpu") or torch.device("cuda"). Get the device from `initialize_device_settings()`
     :param distributed: Whether training on distributed machines
     :param local_rank: rank of the machine in a distributed setting
     :param use_amp: Optimization level of nvidia's automatic mixed precision (AMP). The higher the level, the faster the model.
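`optimize_model` now takes a `torch.device`; in the rest of this commit such devices come from `initialize_device_settings`. A sketch, assuming its v1.x location:

```python
import torch
from haystack.modeling.utils import initialize_device_settings  # assumed v1.x location

# Returns a list of torch.device objects plus a second value this commit discards.
devices, _ = initialize_device_settings(use_cuda=torch.cuda.is_available(), multi_gpu=True)
print(devices[0])  # e.g. device(type='cuda', index=0) or device(type='cpu')
```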
@@ -43,7 +43,7 @@ class TriAdaptiveModel(nn.Module):
         language_model3: LanguageModel,
         prediction_heads: List[PredictionHead],
         embeds_dropout_prob: float = 0.1,
-        device: str = "cuda",
+        device: torch.device = torch.device("cuda"),
         lm1_output_types: Union[str, List[str]] = ["per_sequence"],
         lm2_output_types: Union[str, List[str]] = ["per_sequence"],
         lm3_output_types: Union[str, List[str]] = ["per_sequence"],

@@ -71,7 +71,7 @@ class TriAdaptiveModel(nn.Module):
                                 "per_sequence", a single embedding will be extracted to represent the full
                                 input sequence. Can either be a single string, or a list of strings,
                                 one for each prediction head.
-        :param device: The device on which this model will operate. Either "cpu" or "cuda".
+        :param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
         :param loss_aggregation_fn: Function to aggregate the loss of multiple prediction heads.
                                     Input: loss_per_head (list of tensors), global_step (int), batch (dict)
                                     Output: aggregated loss (tensor)

@@ -129,7 +129,7 @@ class TriAdaptiveModel(nn.Module):
     def load(
         cls,
         load_dir: Path,
-        device: str,
+        device: torch.device,
         strict: bool = False,
         lm1_name: str = "lm1",
         lm2_name: str = "lm2",

@@ -155,7 +155,7 @@ class TriAdaptiveModel(nn.Module):
         * special_tokens_map.json
 
         :param load_dir: Location where the TriAdaptiveModel is stored.
-        :param device: To which device we want to sent the model, either cpu or cuda.
+        :param device: To which device we want to send the model, either torch.device("cpu") or torch.device("cuda").
         :param lm1_name: The name to assign to the first loaded language model (for encoding queries).
         :param lm2_name: The name to assign to the second loaded language model (for encoding context/passages).
         :param lm3_name: The name to assign to the third loaded language model (for encoding tables).
@@ -126,7 +126,7 @@ class Trainer:
         data_silo: DataSilo,
         epochs: int,
         n_gpu: int,
-        device,
+        device: torch.device,
         lr_schedule=None,
         evaluate_every: int = 100,
         eval_report: bool = True,

@@ -152,7 +152,7 @@ class Trainer:
         :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
         :param epochs: How many times the training procedure will loop through the train dataset
         :param n_gpu: The number of gpus available for training and evaluation.
-        :param device: The device on which the train, dev and test tensors should be hosted. Choose from "cpu" and "cuda".
+        :param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
         :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
         :param evaluate_every: Perform dev set evaluation after this many steps of training.
         :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating

@@ -660,7 +660,7 @@ class DistillationTrainer(Trainer):
         data_silo: DistillationDataSilo,
         epochs: int,
         n_gpu: int,
-        device: str,
+        device: torch.device,
         lr_schedule: Optional["_LRScheduler"] = None,
         evaluate_every: int = 100,
         eval_report: bool = True,

@@ -691,7 +691,7 @@ class DistillationTrainer(Trainer):
         :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
         :param epochs: How many times the training procedure will loop through the train dataset
         :param n_gpu: The number of gpus available for training and evaluation.
-        :param device: The device on which the train, dev and test tensors should be hosted. Choose from "cpu" and "cuda".
+        :param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
         :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
         :param evaluate_every: Perform dev set evaluation after this many steps of training.
         :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating

@@ -833,7 +833,7 @@ class TinyBERTDistillationTrainer(Trainer):
         :param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
         :param epochs: How many times the training procedure will loop through the train dataset
         :param n_gpu: The number of gpus available for training and evaluation.
-        :param device: The device on which the train, dev and test tensors should be hosted. Choose from "cpu" and "cuda".
+        :param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
         :param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
         :param evaluate_every: Perform dev set evaluation after this many steps of training.
        :param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
@@ -41,7 +41,7 @@ class SentenceTransformersRanker(BaseRanker):
         model_version: Optional[str] = None,
         top_k: int = 10,
         use_gpu: bool = True,
-        devices: Optional[List[Union[int, str, torch.device]]] = None,
+        devices: Optional[List[Union[str, torch.device]]] = None,
     ):
         """
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.

@@ -50,16 +50,20 @@ class SentenceTransformersRanker(BaseRanker):
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param top_k: The maximum number of documents to return
         :param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
-        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
+        :param devices: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones
+                        The strings will be converted into pytorch devices, so use the string notation described here:
+                        https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
+                        (e.g. ["cuda:0"]).
         """
         super().__init__()
 
         self.top_k = top_k
 
         if devices is not None:
-            self.devices = devices
+            self.devices = [torch.device(device) for device in devices]
         else:
             self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)
 
         self.transformer_model = AutoModelForSequenceClassification.from_pretrained(
             pretrained_model_name_or_path=model_name_or_path, revision=model_version
         )
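The list comprehension above is the normalization idiom this PR repeats in every node; the same thing as a standalone, hypothetical helper:

```python
import torch
from typing import List, Union

def normalize_devices(devices: List[Union[str, torch.device]]) -> List[torch.device]:
    # torch.device(...) passes existing device objects through unchanged.
    return [torch.device(d) for d in devices]

print(normalize_devices(["cuda:0", torch.device("cpu")]))
# [device(type='cuda', index=0), device(type='cpu')]
```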
@@ -777,7 +777,9 @@ class FARMReader(BaseReader):
 
         return result
 
-    def eval_on_file(self, data_dir: str, test_filename: str, device: Optional[str] = None):
+    def eval_on_file(
+        self, data_dir: Union[Path, str], test_filename: str, device: Optional[Union[str, torch.device]] = None
+    ):
         """
         Performs evaluation on a SQuAD-formatted file.
         Returns a dict containing the following metrics:

@@ -786,14 +788,16 @@ class FARMReader(BaseReader):
         - "top_n_accuracy": Proportion of predicted answers that overlap with correct answer
 
         :param data_dir: The directory in which the test set can be found
-        :type data_dir: Path or str
         :param test_filename: The name of the file containing the test data in SQuAD format.
-        :type test_filename: str
-        :param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
-        :type device: str
+        :param device: The device on which the tensors should be processed.
+                       Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
+                       or use the Reader's device by default.
         """
         if device is None:
             device = self.devices[0]
+        else:
+            device = torch.device(device)
+
         eval_processor = SquadProcessor(
             tokenizer=self.inferencer.processor.tokenizer,
             max_seq_len=self.inferencer.processor.max_seq_len,

@@ -822,7 +826,7 @@ class FARMReader(BaseReader):
     def eval(
         self,
         document_store: BaseDocumentStore,
-        device: Optional[str] = None,
+        device: Optional[Union[str, torch.device]] = None,
         label_index: str = "label",
         doc_index: str = "eval_document",
         label_origin: str = "gold-label",

@@ -836,7 +840,9 @@ class FARMReader(BaseReader):
         - "top_n_accuracy": Proportion of predicted answers that overlap with correct answer
 
         :param document_store: DocumentStore containing the evaluation documents
-        :param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
+        :param device: The device on which the tensors should be processed.
+                       Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
+                       or use the Reader's device by default.
         :param label_index: Index/Table name where labeled questions are stored
         :param doc_index: Index/Table name where documents that are used for evaluation are stored
         :param label_origin: Field name where the gold labels are stored

@@ -844,6 +850,9 @@ class FARMReader(BaseReader):
         """
         if device is None:
             device = self.devices[0]
+        else:
+            device = torch.device(device)
+
         if self.top_k_per_candidate != 4:
             logger.info(
                 f"Performing Evaluation using top_k_per_candidate = {self.top_k_per_candidate} \n"

@@ -1012,7 +1021,7 @@ class FARMReader(BaseReader):
     def calibrate_confidence_scores(
         self,
         document_store: BaseDocumentStore,
-        device: Optional[str] = None,
+        device: Optional[Union[str, torch.device]] = None,
         label_index: str = "label",
         doc_index: str = "eval_document",
         label_origin: str = "gold_label",

@@ -1021,7 +1030,9 @@ class FARMReader(BaseReader):
         Calibrates confidence scores on evaluation documents in the DocumentStore.
 
         :param document_store: DocumentStore containing the evaluation documents
-        :param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
+        :param device: The device on which the tensors should be processed.
+                       Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
+                       or use the Reader's device by default.
         :param label_index: Index/Table name where labeled questions are stored
         :param doc_index: Index/Table name where documents that are used for evaluation are stored
         :param label_origin: Field name where the gold labels are stored
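All three eval methods now share one fallback rule: `None` means the Reader's first device, anything else is coerced. The same rule as a standalone, hypothetical helper:

```python
import torch
from typing import Optional, Union

def resolve_device(device: Optional[Union[str, torch.device]], default: torch.device) -> torch.device:
    # None -> caller's default (the Reader's first device); str or torch.device -> torch.device.
    return default if device is None else torch.device(device)

assert resolve_device(None, torch.device("cpu")) == torch.device("cpu")
assert resolve_device("cuda:1", torch.device("cpu")) == torch.device("cuda:1")
```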
@@ -54,7 +54,7 @@ class DensePassageRetriever(BaseRetriever):
         similarity_function: str = "dot_product",
         global_loss_buffer_size: int = 150000,
         progress_bar: bool = True,
-        devices: Optional[List[Union[int, str, torch.device]]] = None,
+        devices: Optional[List[Union[str, torch.device]]] = None,
         use_auth_token: Optional[Union[str, bool]] = None,
     ):
         """

@@ -102,8 +102,11 @@ class DensePassageRetriever(BaseRetriever):
                                         Increase if errors like "encoded data exceeds max_size ..." come up
         :param progress_bar: Whether to show a tqdm progress bar or not.
                              Can be helpful to disable in production deployments to keep the logs clean.
-        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
-                        As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+        :param devices: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones
+                        These strings will be converted into pytorch devices, so use the string notation described here:
+                        https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
+                        (e.g. ["cuda:0"]). Note: as multi-GPU training is currently not implemented for DPR, training
+                        will only use the first device provided in this list.
         :param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
                                the local token will be used, which must be previously created via `transformer-cli login`.
                                Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

@@ -111,7 +114,7 @@ class DensePassageRetriever(BaseRetriever):
         super().__init__()
 
         if devices is not None:
-            self.devices = devices
+            self.devices = [torch.device(device) for device in devices]
         else:
             self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)
 

@@ -193,7 +196,7 @@ class DensePassageRetriever(BaseRetriever):
             embeds_dropout_prob=0.1,
             lm1_output_types=["per_sequence"],
             lm2_output_types=["per_sequence"],
-            device=str(self.devices[0]),
+            device=self.devices[0],
         )
 
         self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)

@@ -548,7 +551,7 @@ class TableTextRetriever(BaseRetriever):
         similarity_function: str = "dot_product",
         global_loss_buffer_size: int = 150000,
         progress_bar: bool = True,
-        devices: Optional[List[Union[int, str, torch.device]]] = None,
+        devices: Optional[List[Union[str, torch.device]]] = None,
         use_auth_token: Optional[Union[str, bool]] = None,
     ):
         """

@@ -582,8 +585,11 @@ class TableTextRetriever(BaseRetriever):
                                         Increase if errors like "encoded data exceeds max_size ..." come up
         :param progress_bar: Whether to show a tqdm progress bar or not.
                              Can be helpful to disable in production deployments to keep the logs clean.
-        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
-                        As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+        :param devices: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones
+                        These strings will be converted into pytorch devices, so use the string notation described here:
+                        https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
+                        (e.g. ["cuda:0"]). Note: as multi-GPU training is currently not implemented for TableTextRetriever,
+                        training will only use the first device provided in this list.
         :param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
                                the local token will be used, which must be previously created via `transformer-cli login`.
                                Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

@@ -591,7 +597,7 @@ class TableTextRetriever(BaseRetriever):
         super().__init__()
 
         if devices is not None:
-            self.devices = devices
+            self.devices = [torch.device(device) for device in devices]
         else:
             self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)
 

@@ -699,7 +705,7 @@ class TableTextRetriever(BaseRetriever):
             lm1_output_types=["per_sequence"],
             lm2_output_types=["per_sequence"],
             lm3_output_types=["per_sequence"],
-            device=str(self.devices[0]),
+            device=self.devices[0],
         )
 
         self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)

@@ -1075,7 +1081,7 @@ class EmbeddingRetriever(BaseRetriever):
         emb_extraction_layer: int = -1,
         top_k: int = 10,
         progress_bar: bool = True,
-        devices: Optional[List[Union[int, str, torch.device]]] = None,
+        devices: Optional[List[Union[str, torch.device]]] = None,
         use_auth_token: Optional[Union[str, bool]] = None,
     ):
         """

@@ -1101,8 +1107,11 @@ class EmbeddingRetriever(BaseRetriever):
                                      Default: -1 (very last layer).
         :param top_k: How many documents to return per query.
         :param progress_bar: If true displays progress bar during embedding.
-        :param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
-                        As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
+        :param devices: List of GPU (or CPU) devices, to limit inference to certain GPUs and not use all available ones
+                        These strings will be converted into pytorch devices, so use the string notation described here:
+                        https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
+                        (e.g. ["cuda:0"]). Note: As multi-GPU training is currently not implemented for EmbeddingRetriever,
+                        training will only use the first device provided in this list.
         :param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
                                the local token will be used, which must be previously created via `transformer-cli login`.
                                Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

@@ -1110,7 +1119,7 @@ class EmbeddingRetriever(BaseRetriever):
         super().__init__()
 
         if devices is not None:
-            self.devices = devices
+            self.devices = [torch.device(device) for device in devices]
         else:
             self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)
 
@@ -23,6 +23,7 @@ Arguments:
     model: Huggingface MLM model identifier.
 """
 
+from typing import Tuple, List, Union
 
 import torch
 from torch.nn import functional as F

@@ -37,13 +38,14 @@ import argparse
 import json
 import logging
 from tqdm import tqdm
-from typing import Tuple, List
 
 logger = logging.getLogger(__name__)
 
 
 def load_glove(
-    glove_path: Path = Path("glove.txt"), vocab_size: int = 100_000, device: str = "cpu:0"
+    glove_path: Path = Path("glove.txt"),
+    vocab_size: int = 100_000,
+    device: Union[str, torch.device] = torch.device("cpu:0"),
 ) -> Tuple[dict, dict, torch.Tensor]:
     """Loads the GloVe vectors and returns a mapping from words to their GloVe vector indices and the other way around."""
 

@@ -112,8 +114,9 @@ def get_replacements(
     text: str,
     word_possibilities: int = 20,
     batch_size: int = 16,
-    device: str = "cpu:0",
+    device: torch.device = torch.device("cpu:0"),
 ) -> List[List[str]]:
+
     """Returns a list of possible replacements for each word in the text."""
     input_ids, words, word_subword_mapping = tokenize_and_extract_words(text, tokenizer)
 

@@ -179,8 +182,9 @@ def augment(
     word_possibilities: int = 20,
     replace_probability: float = 0.4,
     batch_size: int = 16,
-    device: str = "cpu:0",
+    device: Union[str, torch.device] = torch.device("cpu:0"),
 ) -> List[str]:
+    device = torch.device(device)
     # returns a list of different augmented versions of the text
     replacements = get_replacements(
         glove_word_id_mapping=word_id_mapping,

@@ -211,16 +215,17 @@ def augment(
 def augment_squad(
     squad_path: Path,
     output_path: Path,
-    glove_path: Path = Path("glove.txt"),
     model: str = "bert-base-uncased",
     tokenizer: str = "bert-base-uncased",
+    glove_path: Path = Path("glove.txt"),
     multiplication_factor: int = 20,
     word_possibilities: int = 20,
     replace_probability: float = 0.4,
-    device: str = "cpu:0",
+    device: Union[str, torch.device] = "cpu:0",
     batch_size: int = 16,
 ):
     """Loads a squad dataset, augments the contexts, and saves the result in SQuAD format."""
+    device = torch.device(device)
     # loading model and tokenizer
     transformers_model = AutoModelForMaskedLM.from_pretrained(model)
     transformers_model.to(device)
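Since `augment` and `load_glove` are treated as user-facing here, the entry points keep accepting plain strings and coerce them via `torch.device(device)`. A call sketch (paths are placeholders; the import location is an assumption about the v1.x tree):

```python
from pathlib import Path

from haystack.utils.augment_squad import augment_squad  # assumed location

augment_squad(
    squad_path=Path("train-v2.0.json"),  # placeholder input
    output_path=Path("augmented.json"),  # placeholder output
    glove_path=Path("glove.txt"),
    device="cpu:0",  # a string is still fine; coerced with torch.device(device) on entry
)
```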