Set provider parameter when instantiating onnxruntime.InferenceSession (#1976)

* Set provider parameter when instantiating onnxruntime.InferenceSession
fixes #1973

* Change device type to torch.device

* set type annotation of device to torch.device everywhere

* Apply Black

* Change types of device and devices params across the codebase

* Update Documentation & Code Style

* Add type: ignore in the right location

* Update Documentation & Code Style

* Add type: ignore

* feedback

* Update Documentation & Code Style

* feedback 2

* Fix convert_to_transformers

* Fix syntax error

* Update Documentation & Code Style

* Consider augment and load_glove user-facing as well

* Update Documentation & Code Style

* Fix mypy

* Update Documentation & Code Style

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Co-authored-by: Sara Zan <sara.zanzottera@deepset.ai>
Author: Chris Byrd, 2022-03-23 07:08:56 -04:00 (committed via GitHub)
Commit: 3b2001e66f (parent: 851fe1cf07)
16 changed files with 165 additions and 112 deletions


@ -92,7 +92,7 @@ p.add_node(component=ranker, name="Ranker", inputs=["ESRetriever"])
#### \_\_init\_\_
```python
def __init__(model_name_or_path: Union[str, Path], model_version: Optional[str] = None, top_k: int = 10, use_gpu: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None)
def __init__(model_name_or_path: Union[str, Path], model_version: Optional[str] = None, top_k: int = 10, use_gpu: bool = True, devices: Optional[List[Union[str, torch.device]]] = None)
```
**Arguments**:
@ -103,7 +103,10 @@ See https://huggingface.co/cross-encoder for full list of available models
- `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
- `top_k`: The maximum number of documents to return
- `use_gpu`: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
- `devices`: List of GPU (or CPU) devices to limit inference to certain GPUs instead of using all available ones.
The strings will be converted into PyTorch devices, so use the string notation described here:
https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
(e.g. ["cuda:0"]).
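For illustration, a minimal usage sketch (the import path and model name are assumptions based on Haystack 1.x conventions, not part of this diff) showing that both notations are accepted:
```python
import torch
from haystack.nodes import SentenceTransformersRanker  # assumed import path

# Plain strings are converted to torch.device objects internally ...
ranker = SentenceTransformersRanker(
    model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2",
    devices=["cuda:0"],
)

# ... so passing torch.device objects directly is equivalent.
ranker = SentenceTransformersRanker(
    model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2",
    devices=[torch.device("cuda:0")],
)
```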
<a id="sentence_transformers.SentenceTransformersRanker.predict_batch"></a>


@ -398,7 +398,7 @@ Dict containing query and answers
#### eval\_on\_file
```python
def eval_on_file(data_dir: str, test_filename: str, device: Optional[str] = None)
def eval_on_file(data_dir: Union[Path, str], test_filename: str, device: Optional[Union[str, torch.device]] = None)
```
Performs evaluation on a SQuAD-formatted file.
@ -410,16 +410,18 @@ Returns a dict containing the following metrics:
**Arguments**:
- `data_dir` (`Path or str`): The directory in which the test set can be found
- `test_filename` (`str`): The name of the file containing the test data in SQuAD format.
- `device` (`str`): The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
- `data_dir`: The directory in which the test set can be found
- `test_filename`: The name of the file containing the test data in SQuAD format.
- `device`: The device on which the tensors should be processed.
Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
or use the Reader's device by default.
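As a hedged usage sketch (the reader model, data directory, and file name below are placeholders, not from this diff), both a string and a `torch.device` now work:
```python
import torch
from haystack.nodes import FARMReader  # assumed import path

reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

# A plain string works ...
metrics = reader.eval_on_file(data_dir="data/squad20", test_filename="dev-v2.0.json", device="cuda")

# ... and so does a torch.device; both are normalized with torch.device(device) internally.
metrics = reader.eval_on_file(
    data_dir="data/squad20", test_filename="dev-v2.0.json", device=torch.device("cuda")
)
```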
<a id="farm.FARMReader.eval"></a>
#### eval
```python
def eval(document_store: BaseDocumentStore, device: Optional[str] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold-label", calibrate_conf_scores: bool = False)
def eval(document_store: BaseDocumentStore, device: Optional[Union[str, torch.device]] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold-label", calibrate_conf_scores: bool = False)
```
Performs evaluation on evaluation documents in the DocumentStore.
@ -432,7 +434,9 @@ Returns a dict containing the following metrics:
**Arguments**:
- `document_store`: DocumentStore containing the evaluation documents
- `device`: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
- `device`: The device on which the tensors should be processed.
Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
or use the Reader's device by default.
- `label_index`: Index/Table name where labeled questions are stored
- `doc_index`: Index/Table name where documents that are used for evaluation are stored
- `label_origin`: Field name where the gold labels are stored
@ -443,7 +447,7 @@ Returns a dict containing the following metrics:
#### calibrate\_confidence\_scores
```python
def calibrate_confidence_scores(document_store: BaseDocumentStore, device: Optional[str] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label")
def calibrate_confidence_scores(document_store: BaseDocumentStore, device: Optional[Union[str, torch.device]] = None, label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label")
```
Calibrates confidence scores on evaluation documents in the DocumentStore.
@ -451,7 +455,9 @@ Calibrates confidence scores on evaluation documents in the DocumentStore.
**Arguments**:
- `document_store`: DocumentStore containing the evaluation documents
- `device`: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
- `device`: The device on which the tensors should be processed.
Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
or use the Reader's device by default.
- `label_index`: Index/Table name where labeled questions are stored
- `doc_index`: Index/Table name where documents that are used for evaluation are stored
- `label_origin`: Field name where the gold labels are stored


@ -312,7 +312,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
#### \_\_init\_\_
```python
def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
```
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@ -362,8 +362,11 @@ Options: `dot_product` (Default) or `cosine`
Increase if errors like "encoded data exceeds max_size ..." come up
- `progress_bar`: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
- `devices`: List of GPU (or CPU) devices to limit inference to certain GPUs instead of using all available ones.
These strings will be converted into PyTorch devices, so use the string notation described here:
https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
(e.g. ["cuda:0"]). Note: as multi-GPU training is currently not implemented for DPR, training
will only use the first device provided in this list.
- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
the local token will be used, which must be previously created via `transformers-cli login`.
Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
@ -520,7 +523,7 @@ Kostić, Bogdan, et al. (2021): "Multi-modal Retrieval of Tables and Texts Using
#### \_\_init\_\_
```python
def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
def __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-question_encoder", passage_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-passage_encoder", table_embedding_model: Union[Path, str] = "deepset/bert-small-mm_retrieval-table_encoder", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, max_seq_len_table: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_meta_fields: List[str] = ["name", "section_title", "caption"], use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
```
Init the Retriever incl. the three encoder models from a local or remote model checkpoint.
@ -556,8 +559,11 @@ Options: `dot_product` (Default) or `cosine`
Increase if errors like "encoded data exceeds max_size ..." come up
- `progress_bar`: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
- `devices`: List of GPU (or CPU) devices to limit inference to certain GPUs instead of using all available ones.
These strings will be converted into PyTorch devices, so use the string notation described here:
https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
(e.g. ["cuda:0"]). Note: as multi-GPU training is currently not implemented for TableTextRetriever,
training will only use the first device provided in this list.
- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
the local token will be used, which must be previously created via `transformers-cli login`.
Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
@ -695,7 +701,7 @@ class EmbeddingRetriever(BaseRetriever)
#### \_\_init\_\_
```python
def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
```
**Arguments**:
@ -721,8 +727,11 @@ Options:
Default: -1 (very last layer).
- `top_k`: How many documents to return per query.
- `progress_bar`: If true displays progress bar during embedding.
- `devices`: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
- `devices`: List of GPU (or CPU) devices to limit inference to certain GPUs instead of using all available ones.
These strings will be converted into PyTorch devices, so use the string notation described here:
https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
(e.g. ["cuda:0"]). Note: As multi-GPU training is currently not implemented for EmbeddingRetriever,
training will only use the first device provided in this list (see the usage sketch below).
- `use_auth_token`: API token used to download private models from Huggingface. If this parameter is set to `True`,
the local token will be used, which must be previously created via `transformers-cli login`.
Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
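The same `devices` semantics apply to `DensePassageRetriever`, `TableTextRetriever`, and `EmbeddingRetriever` alike. A minimal sketch (import paths are assumed from Haystack 1.x; the document store is just an example):
```python
import torch
from haystack.document_stores import InMemoryDocumentStore  # assumed import path
from haystack.nodes import DensePassageRetriever  # assumed import path

retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    devices=["cuda:0", torch.device("cuda:1")],  # mixed string/torch.device entries are normalized
)
# Note: training would only use the first device in this list ("cuda:0").
```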


@ -1,11 +1,13 @@
import logging
from typing import Union
import torch
from transformers import AutoModelForQuestionAnswering
from haystack.modeling.model import adaptive_model as am
from haystack.modeling.model.language_model import LanguageModel
from haystack.modeling.model.prediction_head import QuestionAnsweringHead
from haystack.modeling.data_handler.processor import Processor
logger = logging.getLogger(__name__)
@ -46,10 +48,10 @@ class Converter:
@staticmethod
def convert_from_transformers(
model_name_or_path,
device,
revision=None,
task_type=None,
processor=None,
device: Union[str, torch.device],
revision: str = None,
task_type: str = "question_answering",
processor: Processor = None,
use_auth_token: Union[bool, str] = None,
**kwargs,
):
@ -65,14 +67,10 @@ class Converter:
- deepset/bert-large-uncased-whole-word-masking-squad2
See https://huggingface.co/models for full list
:param device: "cpu" or "cuda"
:param device: torch.device("cpu") or torch.device("cuda")
:param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
:type revision: str
:param task_type: One of :
- 'question_answering'
More tasks coming soon ...
:param processor: populates prediction head with information coming from tasks
:type processor: Processor
Right now accepts only 'question_answering'.
:param processor: populates prediction head with information coming from tasks.
:return: AdaptiveModel
"""


@ -785,7 +785,7 @@ class DistillationDataSilo(DataSilo):
self,
teacher_model: "FARMReader",
teacher_batch_size: int,
device: str,
device: torch.device,
processor: Processor,
batch_size: int,
eval_batch_size: Optional[int] = None,


@ -20,11 +20,11 @@ class Evaluator:
Handles evaluation of a given model over a specified dataset.
"""
def __init__(self, data_loader: torch.utils.data.DataLoader, tasks, device: str, report: bool = True):
def __init__(self, data_loader: torch.utils.data.DataLoader, tasks, device: torch.device, report: bool = True):
"""
:param data_loader: The PyTorch DataLoader that will return batches of data from the evaluation dataset
:param tasks:
:param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda".
:param device: The device on which the tensors should be processed. Choose from torch.device("cpu") and torch.device("cuda").
:param report: Whether an eval report should be generated (e.g. classification report per class).
"""
self.data_loader = data_loader


@ -128,7 +128,7 @@ class Inferencer:
use_fast: bool = True,
tokenizer_args: Dict = None,
multithreading_rust: bool = True,
devices: Optional[List[Union[int, str, torch.device]]] = None,
devices: Optional[List[torch.device]] = None,
use_auth_token: Union[bool, str] = None,
**kwargs,
):


@ -169,7 +169,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
prediction_heads: List[PredictionHead],
embeds_dropout_prob: float,
lm_output_types: Union[str, List[str]],
device: str,
device: torch.device,
loss_aggregation_fn: Optional[Callable] = None,
):
"""
@ -182,7 +182,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
"per_sequence", a single embedding will be extracted to represent the full
input sequence. Can either be a single string, or a list of strings,
one for each prediction head.
:param device: The device on which this model will operate. Either "cpu" or "cuda".
:param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
:param loss_aggregation_fn: Function to aggregate the loss of multiple prediction heads.
Input: loss_per_head (list of tensors), global_step (int), batch (dict)
Output: aggregated loss (tensor)
@ -258,13 +258,13 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
# Need to save config and pipeline
@classmethod
def load( # type: ignore
def load( # type: ignore
cls,
load_dir: Union[str, Path], # type: ignore
device: str, # type: ignore
strict: bool = True, # type: ignore
lm_name: Optional[str] = None, # type: ignore
processor: Optional[Processor] = None, # type: ignore
load_dir: Union[str, Path],
device: Union[str, torch.device],
strict: bool = True,
lm_name: Optional[str] = None,
processor: Optional[Processor] = None,
):
"""
Loads an AdaptiveModel from a directory. The directory must contain:
@ -277,12 +277,13 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
* vocab.txt vocab file for language model, turning text to Wordpiece Tokens
:param load_dir: Location where the AdaptiveModel is stored.
:param device: To which device we want to send the model, either cpu or cuda.
:param device: To which device we want to send the model, either torch.device("cpu") or torch.device("cuda").
:param lm_name: The name to assign to the loaded language model.
:param strict: Whether to strictly enforce that the keys loaded from saved model match the ones in
the PredictionHead (see torch.nn.module.load_state_dict()).
:param processor: Processor to populate prediction head with information coming from tasks.
"""
device = torch.device(device)
# Language Model
if lm_name:
language_model = LanguageModel.load(load_dir, haystack_lm_name=lm_name)
@ -489,9 +490,9 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
def convert_from_transformers(
cls,
model_name_or_path: Union[str, Path],
device: str,
device: torch.device,
revision: Optional[str] = None,
task_type: Optional[str] = None,
task_type: str = "question_answering",
processor: Optional[Processor] = None,
use_auth_token: Optional[Union[bool, str]] = None,
**kwargs,
@ -509,12 +510,9 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
See https://huggingface.co/models for full list
:param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
:param device: "cpu" or "cuda"
:param task_type: One of :
- 'question_answering'
More tasks coming soon ...
:param device: On which hardware the conversion should take place. Choose from torch.device("cpu") or torch.device("cuda")
:param task_type: 'question_answering'. More tasks coming soon ...
:param processor: Processor to populate prediction head with information coming from tasks.
:type processor: Processor
:return: AdaptiveModel
"""
return conv.Converter.convert_from_transformers(
@ -570,7 +568,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
tokenizer_name_or_path=model_name, task_type=task_type, max_seq_len=256, doc_stride=128, use_fast=True
)
processor.save(output_path)
model = AdaptiveModel.convert_from_transformers(model_name, device="cpu", task_type=task_type)
model = AdaptiveModel.convert_from_transformers(model_name, device=torch.device("cpu"), task_type=task_type)
model.save(output_path)
os.remove(output_path / "language_model.bin") # remove the actual PyTorch model(only configs are required)
@ -617,14 +615,14 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
language_model_class: str,
language: str,
prediction_heads: List[PredictionHead],
device: str,
device: torch.device,
):
"""
:param onnx_session: ? # TODO
:param language_model_class: Class of LanguageModel
:param language: Language the model is trained for.
:param prediction_heads: A list of models that take embeddings and return logits for a given task.
:param device: The device on which this model will operate. Either "cpu" or "cuda".
:param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
"""
import onnxruntime
@ -642,13 +640,14 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
self.device = device
@classmethod
def load(cls, load_dir: Union[str, Path], device: str, **kwargs): # type: ignore
def load(cls, load_dir: Union[str, Path], device: Union[str, torch.device], **kwargs): # type: ignore
"""
Loads an ONNXAdaptiveModel from a directory.
:param load_dir: Location where the ONNXAdaptiveModel is stored.
:param device: The device on which this model will operate. Either "cpu" or "cuda".
:param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
"""
device = torch.device(device)
load_dir = Path(load_dir)
import onnxruntime
@ -657,7 +656,11 @@ class ONNXAdaptiveModel(BaseAdaptiveModel):
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# Use OpenMP optimizations. Only useful for CPU, has little impact for GPUs.
sess_options.intra_op_num_threads = multiprocessing.cpu_count()
onnx_session = onnxruntime.InferenceSession(str(load_dir / "model.onnx"), sess_options)
providers = kwargs.get(
"providers", ["CPUExecutionProvider"] if device.type == "cpu" else ["CUDAExecutionProvider"]
)
onnx_session = onnxruntime.InferenceSession(str(load_dir / "model.onnx"), sess_options, providers=providers)
# Prediction heads
_, ph_config_files = cls._get_prediction_head_files(load_dir, strict=False)
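Because of the `kwargs.get("providers", ...)` fallback above, callers can still pick the execution providers explicitly. A usage sketch (the load directory is a placeholder):
```python
from haystack.modeling.model.adaptive_model import ONNXAdaptiveModel  # assumed import path

# Default behaviour: providers are derived from the device
# ("cpu" -> CPUExecutionProvider, anything else -> CUDAExecutionProvider).
model = ONNXAdaptiveModel.load("my-onnx-model-dir", device="cuda")

# Explicit override, e.g. to allow a CPU fallback if CUDA is unavailable:
model = ONNXAdaptiveModel.load(
    "my-onnx-model-dir",
    device="cuda",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
```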


@ -36,7 +36,7 @@ class BiAdaptiveModel(nn.Module):
language_model2: LanguageModel,
prediction_heads: List[PredictionHead],
embeds_dropout_prob: float = 0.1,
device: str = "cuda",
device: torch.device = torch.device("cuda"),
lm1_output_types: Union[str, List[str]] = ["per_sequence"],
lm2_output_types: Union[str, List[str]] = ["per_sequence"],
loss_aggregation_fn: Optional[Callable] = None,
@ -57,7 +57,7 @@ class BiAdaptiveModel(nn.Module):
"per_sequence", a single embedding will be extracted to represent the full
input sequence. Can either be a single string, or a list of strings,
one for each prediction head.
:param device: The device on which this model will operate. Either "cpu" or "cuda".
:param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
:param loss_aggregation_fn: Function to aggregate the loss of multiple prediction heads.
Input: loss_per_head (list of tensors), global_step (int), batch (dict)
Output: aggregated loss (tensor)
@ -108,7 +108,7 @@ class BiAdaptiveModel(nn.Module):
def load(
cls,
load_dir: Path,
device: str,
device: torch.device,
strict: bool = False,
lm1_name: str = "lm1",
lm2_name: str = "lm2",
@ -130,7 +130,7 @@ class BiAdaptiveModel(nn.Module):
* special_tokens_map.json
:param load_dir: Location where adaptive model is stored.
:param device: To which device we want to send the model, either cpu or cuda.
:param device: To which device we want to send the model, either torch.device("cpu") or torch.device("cuda").
:param lm1_name: The name to assign to the first loaded language model (for encoding queries).
:param lm2_name: The name to assign to the second loaded language model (for encoding context/passages).
:param strict: Whether to strictly enforce that the keys loaded from saved model match the ones in
@ -432,8 +432,8 @@ class BiAdaptiveModel(nn.Module):
cls,
model_name_or_path1: Union[str, Path],
model_name_or_path2: Union[str, Path],
device: str,
task_type: str,
device: torch.device,
task_type: str = "text_similarity",
processor: Optional[Processor] = None,
similarity_function: str = "dot_product",
):
@ -451,9 +451,8 @@ class BiAdaptiveModel(nn.Module):
Exemplary public names:
- facebook/dpr-ctx_encoder-single-nq-base
- deepset/bert-large-uncased-whole-word-masking-squad2
:param device: "cpu" or "cuda"
:param task_type: 'text_similarity'
More tasks coming soon ...
:param device: The hardware on which the conversion is going to run. Either torch.device("cpu") or torch.device("cuda").
:param task_type: 'text_similarity'. More tasks coming soon ...
:param processor: populates prediction head with information coming from tasks
:type processor: Processor
:return: AdaptiveModel


@ -1,5 +1,5 @@
# TODO analyse if this optimization is needed or whether we can use HF transformers code
from typing import Dict, Any
from typing import Dict, Any, Optional
import inspect
import logging
@ -73,7 +73,7 @@ def initialize_optimizer(
model: AdaptiveModel,
n_batches: int,
n_epochs: int,
device,
device: torch.device,
learning_rate: float,
optimizer_opts: Dict[Any, Any] = None,
schedule_opts: Dict[Any, Any] = None,
@ -90,7 +90,7 @@ def initialize_optimizer(
:param model: model to optimize (e.g. trimming weights to fp16 / mixed precision)
:param n_batches: number of batches for training
:param n_epochs: number of epochs for training
:param device:
:param device: Which hardware will be used by the optimizer. Either torch.device("cpu") or torch.device("cuda").
:param learning_rate: Learning rate
:param optimizer_opts: Dict to customize the optimizer. Choose any optimizer available from torch.optim, apex.optimizers or
transformers.optimization by supplying the class name and the parameters for the constructor.
@ -295,14 +295,20 @@ def get_scheduler(optimizer, opts):
return scheduler
def optimize_model(model, device, local_rank, optimizer=None, distributed=False, use_amp=None):
def optimize_model(
model: "AdaptiveModel",
device: torch.device,
local_rank: int,
optimizer=None,
distributed: Optional[bool] = False,
use_amp: Optional[str] = None,
):
"""
Wraps MultiGPU or distributed usage around a model
No support for ONNX models
:param model: model to optimize (e.g. trimming weights to fp16 / mixed precision)
:type model: AdaptiveModel
:param device: either gpu or cpu, get the device from initialize_device_settings()
:param device: Either torch.device("cpu") or torch.device("cuda"); get the device from `initialize_device_settings()` (see the sketch below).
:param distributed: Whether training on distributed machines
:param local_rank: rank of the machine in a distributed setting
:param use_amp: Optimization level of nvidia's automatic mixed precision (AMP). The higher the level, the faster the model.


@ -43,7 +43,7 @@ class TriAdaptiveModel(nn.Module):
language_model3: LanguageModel,
prediction_heads: List[PredictionHead],
embeds_dropout_prob: float = 0.1,
device: str = "cuda",
device: torch.device = torch.device("cuda"),
lm1_output_types: Union[str, List[str]] = ["per_sequence"],
lm2_output_types: Union[str, List[str]] = ["per_sequence"],
lm3_output_types: Union[str, List[str]] = ["per_sequence"],
@ -71,7 +71,7 @@ class TriAdaptiveModel(nn.Module):
"per_sequence", a single embedding will be extracted to represent the full
input sequence. Can either be a single string, or a list of strings,
one for each prediction head.
:param device: The device on which this model will operate. Either "cpu" or "cuda".
:param device: The device on which this model will operate. Either torch.device("cpu") or torch.device("cuda").
:param loss_aggregation_fn: Function to aggregate the loss of multiple prediction heads.
Input: loss_per_head (list of tensors), global_step (int), batch (dict)
Output: aggregated loss (tensor)
@ -129,7 +129,7 @@ class TriAdaptiveModel(nn.Module):
def load(
cls,
load_dir: Path,
device: str,
device: torch.device,
strict: bool = False,
lm1_name: str = "lm1",
lm2_name: str = "lm2",
@ -155,7 +155,7 @@ class TriAdaptiveModel(nn.Module):
* special_tokens_map.json
:param load_dir: Location where the TriAdaptiveModel is stored.
:param device: To which device we want to send the model, either cpu or cuda.
:param device: To which device we want to send the model, either torch.device("cpu") or torch.device("cuda").
:param lm1_name: The name to assign to the first loaded language model (for encoding queries).
:param lm2_name: The name to assign to the second loaded language model (for encoding context/passages).
:param lm3_name: The name to assign to the third loaded language model (for encoding tables).


@ -126,7 +126,7 @@ class Trainer:
data_silo: DataSilo,
epochs: int,
n_gpu: int,
device,
device: torch.device,
lr_schedule=None,
evaluate_every: int = 100,
eval_report: bool = True,
@ -152,7 +152,7 @@ class Trainer:
:param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
:param epochs: How many times the training procedure will loop through the train dataset
:param n_gpu: The number of gpus available for training and evaluation.
:param device: The device on which the train, dev and test tensors should be hosted. Choose from "cpu" and "cuda".
:param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
:param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
:param evaluate_every: Perform dev set evaluation after this many steps of training.
:param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
@ -660,7 +660,7 @@ class DistillationTrainer(Trainer):
data_silo: DistillationDataSilo,
epochs: int,
n_gpu: int,
device: str,
device: torch.device,
lr_schedule: Optional["_LRScheduler"] = None,
evaluate_every: int = 100,
eval_report: bool = True,
@ -691,7 +691,7 @@ class DistillationTrainer(Trainer):
:param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
:param epochs: How many times the training procedure will loop through the train dataset
:param n_gpu: The number of gpus available for training and evaluation.
:param device: The device on which the train, dev and test tensors should be hosted. Choose from "cpu" and "cuda".
:param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
:param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
:param evaluate_every: Perform dev set evaluation after this many steps of training.
:param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating
@ -833,7 +833,7 @@ class TinyBERTDistillationTrainer(Trainer):
:param data_silo: A DataSilo object that will contain the train, dev and test datasets as PyTorch DataLoaders
:param epochs: How many times the training procedure will loop through the train dataset
:param n_gpu: The number of gpus available for training and evaluation.
:param device: The device on which the train, dev and test tensors should be hosted. Choose from "cpu" and "cuda".
:param device: The device on which the train, dev and test tensors should be hosted. Choose from torch.device("cpu") and torch.device("cuda").
:param lr_schedule: An optional scheduler object that can regulate the learning rate of the optimizer
:param evaluate_every: Perform dev set evaluation after this many steps of training.
:param eval_report: If evaluate_every is not 0, specifies if an eval report should be generated when evaluating


@ -41,7 +41,7 @@ class SentenceTransformersRanker(BaseRanker):
model_version: Optional[str] = None,
top_k: int = 10,
use_gpu: bool = True,
devices: Optional[List[Union[int, str, torch.device]]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
):
"""
:param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@ -50,16 +50,20 @@ class SentenceTransformersRanker(BaseRanker):
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
:param top_k: The maximum number of documents to return
:param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
:param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
:param devices: List of GPU (or CPU) devices to limit inference to certain GPUs instead of using all available ones.
The strings will be converted into PyTorch devices, so use the string notation described here:
https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
(e.g. ["cuda:0"]).
"""
super().__init__()
self.top_k = top_k
if devices is not None:
self.devices = devices
self.devices = [torch.device(device) for device in devices]
else:
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)
self.transformer_model = AutoModelForSequenceClassification.from_pretrained(
pretrained_model_name_or_path=model_name_or_path, revision=model_version
)


@ -777,7 +777,9 @@ class FARMReader(BaseReader):
return result
def eval_on_file(self, data_dir: str, test_filename: str, device: Optional[str] = None):
def eval_on_file(
self, data_dir: Union[Path, str], test_filename: str, device: Optional[Union[str, torch.device]] = None
):
"""
Performs evaluation on a SQuAD-formatted file.
Returns a dict containing the following metrics:
@ -786,14 +788,16 @@ class FARMReader(BaseReader):
- "top_n_accuracy": Proportion of predicted answers that overlap with correct answer
:param data_dir: The directory in which the test set can be found
:type data_dir: Path or str
:param test_filename: The name of the file containing the test data in SQuAD format.
:type test_filename: str
:param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
:type device: str
:param device: The device on which the tensors should be processed.
Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
or use the Reader's device by default.
"""
if device is None:
device = self.devices[0]
else:
device = torch.device(device)
eval_processor = SquadProcessor(
tokenizer=self.inferencer.processor.tokenizer,
max_seq_len=self.inferencer.processor.max_seq_len,
@ -822,7 +826,7 @@ class FARMReader(BaseReader):
def eval(
self,
document_store: BaseDocumentStore,
device: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
label_index: str = "label",
doc_index: str = "eval_document",
label_origin: str = "gold-label",
@ -836,7 +840,9 @@ class FARMReader(BaseReader):
- "top_n_accuracy": Proportion of predicted answers that overlap with correct answer
:param document_store: DocumentStore containing the evaluation documents
:param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
:param device: The device on which the tensors should be processed.
Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
or use the Reader's device by default.
:param label_index: Index/Table name where labeled questions are stored
:param doc_index: Index/Table name where documents that are used for evaluation are stored
:param label_origin: Field name where the gold labels are stored
@ -844,6 +850,9 @@ class FARMReader(BaseReader):
"""
if device is None:
device = self.devices[0]
else:
device = torch.device(device)
if self.top_k_per_candidate != 4:
logger.info(
f"Performing Evaluation using top_k_per_candidate = {self.top_k_per_candidate} \n"
@ -1012,7 +1021,7 @@ class FARMReader(BaseReader):
def calibrate_confidence_scores(
self,
document_store: BaseDocumentStore,
device: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
label_index: str = "label",
doc_index: str = "eval_document",
label_origin: str = "gold_label",
@ -1021,7 +1030,9 @@ class FARMReader(BaseReader):
Calibrates confidence scores on evaluation documents in the DocumentStore.
:param document_store: DocumentStore containing the evaluation documents
:param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda" or use the Reader's device by default.
:param device: The device on which the tensors should be processed.
Choose from torch.device("cpu") and torch.device("cuda") (or simply "cpu" or "cuda")
or use the Reader's device by default.
:param label_index: Index/Table name where labeled questions are stored
:param doc_index: Index/Table name where documents that are used for evaluation are stored
:param label_origin: Field name where the gold labels are stored


@ -54,7 +54,7 @@ class DensePassageRetriever(BaseRetriever):
similarity_function: str = "dot_product",
global_loss_buffer_size: int = 150000,
progress_bar: bool = True,
devices: Optional[List[Union[int, str, torch.device]]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
use_auth_token: Optional[Union[str, bool]] = None,
):
"""
@ -102,8 +102,11 @@ class DensePassageRetriever(BaseRetriever):
Increase if errors like "encoded data exceeds max_size ..." come up
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
:param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
:param devices: List of GPU (or CPU) devices to limit inference to certain GPUs instead of using all available ones.
These strings will be converted into PyTorch devices, so use the string notation described here:
https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
(e.g. ["cuda:0"]). Note: as multi-GPU training is currently not implemented for DPR, training
will only use the first device provided in this list.
:param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
the local token will be used, which must be previously created via `transformers-cli login`.
Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
@ -111,7 +114,7 @@ class DensePassageRetriever(BaseRetriever):
super().__init__()
if devices is not None:
self.devices = devices
self.devices = [torch.device(device) for device in devices]
else:
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)
@ -193,7 +196,7 @@ class DensePassageRetriever(BaseRetriever):
embeds_dropout_prob=0.1,
lm1_output_types=["per_sequence"],
lm2_output_types=["per_sequence"],
device=str(self.devices[0]),
device=self.devices[0],
)
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
@ -548,7 +551,7 @@ class TableTextRetriever(BaseRetriever):
similarity_function: str = "dot_product",
global_loss_buffer_size: int = 150000,
progress_bar: bool = True,
devices: Optional[List[Union[int, str, torch.device]]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
use_auth_token: Optional[Union[str, bool]] = None,
):
"""
@ -582,8 +585,11 @@ class TableTextRetriever(BaseRetriever):
Increase if errors like "encoded data exceeds max_size ..." come up
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
:param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
:param devices: List of GPU (or CPU) devices to limit inference to certain GPUs instead of using all available ones.
These strings will be converted into PyTorch devices, so use the string notation described here:
https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
(e.g. ["cuda:0"]). Note: as multi-GPU training is currently not implemented for TableTextRetriever,
training will only use the first device provided in this list.
:param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
the local token will be used, which must be previously created via `transformers-cli login`.
Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
@ -591,7 +597,7 @@ class TableTextRetriever(BaseRetriever):
super().__init__()
if devices is not None:
self.devices = devices
self.devices = [torch.device(device) for device in devices]
else:
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)
@ -699,7 +705,7 @@ class TableTextRetriever(BaseRetriever):
lm1_output_types=["per_sequence"],
lm2_output_types=["per_sequence"],
lm3_output_types=["per_sequence"],
device=str(self.devices[0]),
device=self.devices[0],
)
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
@ -1075,7 +1081,7 @@ class EmbeddingRetriever(BaseRetriever):
emb_extraction_layer: int = -1,
top_k: int = 10,
progress_bar: bool = True,
devices: Optional[List[Union[int, str, torch.device]]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
use_auth_token: Optional[Union[str, bool]] = None,
):
"""
@ -1101,8 +1107,11 @@ class EmbeddingRetriever(BaseRetriever):
Default: -1 (very last layer).
:param top_k: How many documents to return per query.
:param progress_bar: If true displays progress bar during embedding.
:param devices: List of GPU devices to limit inference to certain GPUs and not use all available ones (e.g. ["cuda:0"]).
As multi-GPU training is currently not implemented for DPR, training will only use the first device provided in this list.
:param devices: List of GPU (or CPU) devices to limit inference to certain GPUs instead of using all available ones.
These strings will be converted into PyTorch devices, so use the string notation described here:
https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
(e.g. ["cuda:0"]). Note: As multi-GPU training is currently not implemented for EmbeddingRetriever,
training will only use the first device provided in this list.
:param use_auth_token: API token used to download private models from Huggingface. If this parameter is set to `True`,
the local token will be used, which must be previously created via `transformers-cli login`.
Additional information can be found here https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
@ -1110,7 +1119,7 @@ class EmbeddingRetriever(BaseRetriever):
super().__init__()
if devices is not None:
self.devices = devices
self.devices = [torch.device(device) for device in devices]
else:
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True)
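The blanket `[torch.device(device) for device in devices]` normalization used in these constructors is safe because `torch.device()` accepts both strings and existing `torch.device` objects; the same idempotent call backs the `device = torch.device(device)` lines added elsewhere in this commit. A quick illustration:
```python
import torch

devices = ["cuda:0", torch.device("cpu")]        # mixed user input
normalized = [torch.device(d) for d in devices]  # what the constructors now do

assert normalized == [torch.device("cuda:0"), torch.device("cpu")]
assert torch.device(torch.device("cuda:0")) == torch.device("cuda:0")  # idempotent
```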


@ -23,6 +23,7 @@ Arguments:
model: Huggingface MLM model identifier.
"""
from typing import Tuple, List, Union
import torch
from torch.nn import functional as F
@ -37,13 +38,14 @@ import argparse
import json
import logging
from tqdm import tqdm
from typing import Tuple, List
logger = logging.getLogger(__name__)
def load_glove(
glove_path: Path = Path("glove.txt"), vocab_size: int = 100_000, device: str = "cpu:0"
glove_path: Path = Path("glove.txt"),
vocab_size: int = 100_000,
device: Union[str, torch.device] = torch.device("cpu:0"),
) -> Tuple[dict, dict, torch.Tensor]:
"""Loads the GloVe vectors and returns a mapping from words to their GloVe vector indices and the other way around."""
@ -112,8 +114,9 @@ def get_replacements(
text: str,
word_possibilities: int = 20,
batch_size: int = 16,
device: str = "cpu:0",
device: torch.device = torch.device("cpu:0"),
) -> List[List[str]]:
"""Returns a list of possible replacements for each word in the text."""
input_ids, words, word_subword_mapping = tokenize_and_extract_words(text, tokenizer)
@ -179,8 +182,9 @@ def augment(
word_possibilities: int = 20,
replace_probability: float = 0.4,
batch_size: int = 16,
device: str = "cpu:0",
device: Union[str, torch.device] = torch.device("cpu:0"),
) -> List[str]:
device = torch.device(device)
# returns a list of different augmented versions of the text
replacements = get_replacements(
glove_word_id_mapping=word_id_mapping,
@ -211,16 +215,17 @@ def augment(
def augment_squad(
squad_path: Path,
output_path: Path,
glove_path: Path = Path("glove.txt"),
model: str = "bert-base-uncased",
tokenizer: str = "bert-base-uncased",
glove_path: Path = Path("glove.txt"),
multiplication_factor: int = 20,
word_possibilities: int = 20,
replace_probability: float = 0.4,
device: str = "cpu:0",
device: Union[str, torch.device] = "cpu:0",
batch_size: int = 16,
):
"""Loads a squad dataset, augments the contexts, and saves the result in SQuAD format."""
device = torch.device(device)
# loading model and tokenizer
transformers_model = AutoModelForMaskedLM.from_pretrained(model)
transformers_model.to(device)
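Finally, a usage sketch of the updated script entry point (all paths are placeholders and the module path is an assumption):
```python
import torch
from pathlib import Path
from haystack.utils.augment_squad import augment_squad  # assumed module path

augment_squad(
    squad_path=Path("train-v2.0.json"),
    output_path=Path("augmented.json"),
    glove_path=Path("glove.txt"),
    device=torch.device("cuda:0"),  # a plain string like "cuda:0" also works
)
```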