mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 11:56:35 +00:00
Add ONNXRuntime support (#157)
This commit is contained in:
parent
54e85e586e
commit
03acb1ee32
@ -242,7 +242,7 @@ class FARMReader(BaseReader):
|
|||||||
|
|
||||||
# get answers from QA model
|
# get answers from QA model
|
||||||
predictions = self.inferencer.inference_from_dicts(
|
predictions = self.inferencer.inference_from_dicts(
|
||||||
dicts=input_dicts, rest_api_schema=True, multiprocessing_chunksize=1
|
dicts=input_dicts, return_json=True, multiprocessing_chunksize=1
|
||||||
)
|
)
|
||||||
# assemble answers from all the different documents & format them.
|
# assemble answers from all the different documents & format them.
|
||||||
# For the "no answer" option, we collect all no_ans_gaps and decide how likely
|
# For the "no answer" option, we collect all no_ans_gaps and decide how likely
|
||||||
@ -442,3 +442,23 @@ class FARMReader(BaseReader):
|
|||||||
)
|
)
|
||||||
predictions = self.predict(question, documents, top_k)
|
predictions = self.predict(question, documents, top_k)
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
@classmethod
def convert_to_onnx(
    cls,
    model_name_or_path,
    opset_version: int = 11,
    optimize_for: Optional[str] = None,
    output_path: Optional[Path] = None,
):
    """
    Convert a PyTorch BERT model to ONNX format and write it to an export directory
    (./onnx-export by default). The converted ONNX model can be loaded in the
    `FARMReader` by passing the export path as the `model_name_or_path` param.

    Usage:
        >>> from haystack.reader.farm import FARMReader
        >>> FARMReader.convert_to_onnx(model_name_or_path="deepset/bert-base-cased-squad2", optimize_for="gpu_tensor_core")
        >>> FARMReader(model_name_or_path=Path("onnx-export"))

    :param model_name_or_path: name or path of the question answering model to convert. It is
                               passed unchanged to `Inferencer.load()`.
    :param opset_version: ONNX opset version
    :param optimize_for: optimize the exported model for a target device. Available options
                         are "gpu_tensor_core" (GPUs with tensor core like V100 or T4),
                         "gpu_without_tensor_core" (most other GPUs), and "cpu".
    :param output_path: directory the converted model is written to. When omitted, the
                        relative directory "onnx-export" is used.
    """
    # Fall back to the historical export location so existing callers see no change.
    export_dir = Path("onnx-export") if output_path is None else Path(output_path)

    # The model is loaded through an Inferencer with the question answering task type
    # and the conversion itself is delegated to the loaded model.
    inferencer = Inferencer.load(model_name_or_path, task_type="question_answering")
    inferencer.model.convert_to_onnx(
        output_path=export_dir,
        opset_version=opset_version,
        optimize_for=optimize_for,
    )
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
farm==0.4.3
|
farm==0.4.4
|
||||||
--find-links=https://download.pytorch.org/whl/torch_stable.html
|
--find-links=https://download.pytorch.org/whl/torch_stable.html
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
|
Loading…
x
Reference in New Issue
Block a user