	Add ONNXRuntime support (#157)
commit 03acb1ee32
parent 54e85e586e
Author: Tanay Soni
haystack/reader/farm.py
@@ -242,7 +242,7 @@ class FARMReader(BaseReader):
 
         # get answers from QA model
         predictions = self.inferencer.inference_from_dicts(
-            dicts=input_dicts, rest_api_schema=True, multiprocessing_chunksize=1
+            dicts=input_dicts, return_json=True, multiprocessing_chunksize=1
         )
         # assemble answers from all the different documents & format them.
         # For the "no answer" option, we collect all no_ans_gaps and decide how likely
@@ -442,3 +442,23 @@ class FARMReader(BaseReader):
             )
         predictions = self.predict(question, documents, top_k)
         return predictions
+
+    @classmethod
+    def convert_to_onnx(cls, model_name_or_path, opset_version: int = 11, optimize_for: Optional[str] = None):
+        """
+        Convert a PyTorch BERT model to ONNX format and write it to the ./onnx-export dir. The converted
+        ONNX model can be loaded in the `FARMReader` using the export path as the `model_name_or_path` param.
+
+        Usage:
+            >>> from haystack.reader.farm import FARMReader
+            >>> FARMReader.convert_to_onnx(model_name_or_path="deepset/bert-base-cased-squad2", optimize_for="gpu_tensor_core")
+            >>> FARMReader(model_name_or_path=Path("onnx-export"))
+
+        :param model_name_or_path: local path to a model directory or name of a public model to convert
+        :param opset_version: ONNX opset version
+        :param optimize_for: optimize the exported model for a target device. Available options
+                             are "gpu_tensor_core" (GPUs with tensor cores like V100 or T4),
+                             "gpu_without_tensor_core" (most other GPUs), and "cpu".
+        """
+        inferencer = Inferencer.load(model_name_or_path, task_type="question_answering")
+        inferencer.model.convert_to_onnx(output_path=Path("onnx-export"), opset_version=opset_version, optimize_for=optimize_for)
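For orientation, here is a minimal end-to-end sketch of the new API added above, following the classmethod's own docstring (the model name and optimize_for value are the docstring's examples; the export path ./onnx-export is fixed by the implementation):

from pathlib import Path

from haystack.reader.farm import FARMReader

# One-time conversion: loads the PyTorch model through FARM's Inferencer
# and writes the ONNX export to ./onnx-export.
FARMReader.convert_to_onnx(
    model_name_or_path="deepset/bert-base-cased-squad2",
    opset_version=11,                # default opset of the new classmethod
    optimize_for="gpu_tensor_core",  # V100/T4; "cpu" or "gpu_without_tensor_core" otherwise
)

# The export directory is then passed as model_name_or_path, and the reader
# runs the model on ONNXRuntime instead of PyTorch.
reader = FARMReader(model_name_or_path=Path("onnx-export"))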
requirements.txt
@@ -1,4 +1,4 @@
-farm==0.4.3
+farm==0.4.4
 --find-links=https://download.pytorch.org/whl/torch_stable.html
 fastapi
 uvicorn
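The requirements bump is what drives the first hunk: farm==0.4.4 renames the Inferencer flag rest_api_schema to return_json. A direct call against the new flag would look roughly like this sketch; the call signature matches the diff, but the QA input-dict layout with "questions"/"text" keys is an assumption about FARM's REST-style schema, not something shown in this commit:

from farm.infer import Inferencer

inferencer = Inferencer.load("deepset/bert-base-cased-squad2", task_type="question_answering")

# Assumed REST-style QA input: a list of questions plus the text to answer from.
qa_input = [{"questions": ["What does FARMReader wrap?"],
             "text": "FARMReader wraps FARM's Inferencer for extractive question answering."}]

# return_json=True replaces rest_api_schema=True from farm==0.4.3.
predictions = inferencer.inference_from_dicts(
    dicts=qa_input, return_json=True, multiprocessing_chunksize=1
)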