Add ONNXRuntime support (#157)

This commit is contained in:
Tanay Soni 2020-06-18 17:47:16 +02:00 committed by GitHub
parent 54e85e586e
commit 03acb1ee32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 2 deletions

View File

@@ -242,7 +242,7 @@ class FARMReader(BaseReader):
# get answers from QA model # get answers from QA model
predictions = self.inferencer.inference_from_dicts( predictions = self.inferencer.inference_from_dicts(
dicts=input_dicts, rest_api_schema=True, multiprocessing_chunksize=1 dicts=input_dicts, return_json=True, multiprocessing_chunksize=1
) )
# assemble answers from all the different documents & format them. # assemble answers from all the different documents & format them.
# For the "no answer" option, we collect all no_ans_gaps and decide how likely # For the "no answer" option, we collect all no_ans_gaps and decide how likely
@@ -442,3 +442,23 @@ class FARMReader(BaseReader):
) )
predictions = self.predict(question, documents, top_k) predictions = self.predict(question, documents, top_k)
return predictions return predictions
@classmethod
def convert_to_onnx(
    cls,
    model_name_or_path: str,
    opset_version: int = 11,
    optimize_for: Optional[str] = None,
    output_path: Path = Path("onnx-export"),
):
    """
    Convert a PyTorch BERT model to ONNX format and write it to the output directory
    (default: ./onnx-export). The converted ONNX model can be loaded in the `FARMReader`
    using the export path as the `model_name_or_path` param.

    Usage:

        >>> from haystack.reader.farm import FARMReader
        >>> FARMReader.convert_to_onnx(model_name_or_path="deepset/bert-base-cased-squad2", optimize_for="gpu_tensor_core")
        >>> FARMReader(model_name_or_path=Path("onnx-export"))

    :param model_name_or_path: name or path of the model to convert
                               (forwarded unchanged to `Inferencer.load`)
    :param opset_version: ONNX opset version
    :param optimize_for: optimize the exported model for a target device. Available options
                         are "gpu_tensor_core" (GPUs with tensor core like V100 or T4),
                         "gpu_without_tensor_core" (most other GPUs), and "cpu".
    :param output_path: directory the ONNX export is written to; defaults to
                        Path("onnx-export") to preserve the original behavior.
    """
    # Load through FARM's Inferencer so the exported graph includes the QA head.
    inferencer = Inferencer.load(model_name_or_path, task_type="question_answering")
    inferencer.model.convert_to_onnx(
        output_path=output_path,
        opset_version=opset_version,
        optimize_for=optimize_for,
    )

View File

@@ -1,4 +1,4 @@
farm==0.4.3 farm==0.4.4
--find-links=https://download.pytorch.org/whl/torch_stable.html --find-links=https://download.pytorch.org/whl/torch_stable.html
fastapi fastapi
uvicorn uvicorn