mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-08 22:03:54 +00:00
Update ONNX conversion for FARMReader (#438)
This commit is contained in:
parent
bb4802ae6a
commit
93fd4aa72f
@ -12,7 +12,7 @@ from farm.data_handler.inputs import QAInput, Question
|
|||||||
from farm.infer import QAInferencer
|
from farm.infer import QAInferencer
|
||||||
from farm.modeling.optimization import initialize_optimizer
|
from farm.modeling.optimization import initialize_optimizer
|
||||||
from farm.modeling.predictions import QAPred, QACandidate
|
from farm.modeling.predictions import QAPred, QACandidate
|
||||||
from farm.modeling.adaptive_model import BaseAdaptiveModel
|
from farm.modeling.adaptive_model import BaseAdaptiveModel, AdaptiveModel
|
||||||
from farm.train import Trainer
|
from farm.train import Trainer
|
||||||
from farm.eval import Evaluator
|
from farm.eval import Evaluator
|
||||||
from farm.utils import set_all_seeds, initialize_device_settings
|
from farm.utils import set_all_seeds, initialize_device_settings
|
||||||
@ -567,7 +567,15 @@ class FARMReader(BaseReader):
|
|||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_to_onnx(cls, model_name_or_path, opset_version: int = 11, optimize_for: Optional[str] = None):
|
def convert_to_onnx(
|
||||||
|
cls,
|
||||||
|
model_name: str,
|
||||||
|
output_path: Path,
|
||||||
|
convert_to_float16: bool = False,
|
||||||
|
quantize: bool = False,
|
||||||
|
task_type: str = "question_answering",
|
||||||
|
opset_version: int = 11
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Convert a PyTorch BERT model to ONNX format and write to ./onnx-export dir. The converted ONNX model
|
Convert a PyTorch BERT model to ONNX format and write to ./onnx-export dir. The converted ONNX model
|
||||||
can be loaded in the `FARMReader` using the export path as `model_name_or_path` param.
|
can be loaded in the `FARMReader` using the export path as `model_name_or_path` param.
|
||||||
@ -575,14 +583,26 @@ class FARMReader(BaseReader):
|
|||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
`from haystack.reader.farm import FARMReader
|
`from haystack.reader.farm import FARMReader
|
||||||
FARMReader.convert_to_onnx(model_name_or_path="deepset/bert-base-cased-squad2", optimize_for="gpu_tensor_core")
|
from pathlib import Path
|
||||||
FARMReader(model_name_or_path=Path("onnx-export"))`
|
onnx_model_path = Path("roberta-onnx-model")
|
||||||
|
FARMReader.convert_to_onnx(model_name="deepset/bert-base-cased-squad2", output_path=onnx_model_path)
|
||||||
|
reader = FARMReader(onnx_model_path)`
|
||||||
|
|
||||||
|
:param model_name: transformers model name
|
||||||
|
:param output_path: Path to output the converted model
|
||||||
|
:param convert_to_float16: Many models use float32 precision by default. With the half precision of float16,
|
||||||
|
inference is faster on Nvidia GPUs with Tensor core like T4 or V100. On older GPUs,
|
||||||
|
float32 could still be more performant.
|
||||||
|
:param quantize: convert floating point numbers to integers
|
||||||
|
:param task_type: Type of task for the model. Available options: "question_answering" or "embeddings".
|
||||||
:param opset_version: ONNX opset version
|
:param opset_version: ONNX opset version
|
||||||
:param optimize_for: Optimize the exported model for a target device. Available options
|
|
||||||
are "gpu_tensor_core" (GPUs with tensor core like V100 or T4),
|
|
||||||
"gpu_without_tensor_core" (most other GPUs), and "cpu".
|
|
||||||
"""
|
"""
|
||||||
inferencer = QAInferencer.load(model_name_or_path, task_type="question_answering")
|
AdaptiveModel.convert_to_onnx(
|
||||||
inferencer.model.convert_to_onnx(output_path=Path("onnx-export"), opset_version=opset_version, optimize_for=optimize_for)
|
model_name=model_name,
|
||||||
|
output_path=output_path,
|
||||||
|
task_type=task_type,
|
||||||
|
convert_to_float16=convert_to_float16,
|
||||||
|
quantize=quantize,
|
||||||
|
opset_version=opset_version
|
||||||
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user