From 03acb1ee321f010c2acd86a88f9bc521e4f786b4 Mon Sep 17 00:00:00 2001 From: Tanay Soni Date: Thu, 18 Jun 2020 17:47:16 +0200 Subject: [PATCH] Add ONNXRuntime support (#157) --- haystack/reader/farm.py | 22 +++++++++++++++++++++- requirements.txt | 2 +- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index c6e16a87f..84d3afccb 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -242,7 +242,7 @@ class FARMReader(BaseReader): # get answers from QA model predictions = self.inferencer.inference_from_dicts( - dicts=input_dicts, rest_api_schema=True, multiprocessing_chunksize=1 + dicts=input_dicts, return_json=True, multiprocessing_chunksize=1 ) # assemble answers from all the different documents & format them. # For the "no answer" option, we collect all no_ans_gaps and decide how likely @@ -442,3 +442,23 @@ class FARMReader(BaseReader): ) predictions = self.predict(question, documents, top_k) return predictions + + @classmethod + def convert_to_onnx(cls, model_name_or_path, opset_version: int = 11, optimize_for: Optional[str] = None): + """ + Convert a PyTorch BERT model to ONNX format and write to ./onnx-export dir. The converted ONNX model + can be loaded in the `FARMReader` using the export path as the `model_name_or_path` param. + + Usage: + >>> from haystack.reader.farm import FARMReader + >>> FARMReader.convert_to_onnx(model_name_or_path="deepset/bert-base-cased-squad2", optimize_for="gpu_tensor_core") + >>> FARMReader(model_name_or_path=Path("onnx-export")) + + + :param opset_version: ONNX opset version + :param optimize_for: optimize the exported model for a target device. Available options + are "gpu_tensor_core" (GPUs with tensor core like V100 or T4), + "gpu_without_tensor_core" (most other GPUs), and "cpu". 
+ """ + inferencer = Inferencer.load(model_name_or_path, task_type="question_answering") + inferencer.model.convert_to_onnx(output_path=Path("onnx-export"), opset_version=opset_version, optimize_for=optimize_for) diff --git a/requirements.txt b/requirements.txt index 0da3ad3ae..0d58053fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -farm==0.4.3 +farm==0.4.4 --find-links=https://download.pytorch.org/whl/torch_stable.html fastapi uvicorn