mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-03 19:29:32 +00:00 
			
		
		
		
	Update ONNX conversion for FARMReader (#438)
This commit is contained in:
		
							parent
							
								
									bb4802ae6a
								
							
						
					
					
						commit
						93fd4aa72f
					
				@ -12,7 +12,7 @@ from farm.data_handler.inputs import QAInput, Question
 | 
			
		||||
from farm.infer import QAInferencer
 | 
			
		||||
from farm.modeling.optimization import initialize_optimizer
 | 
			
		||||
from farm.modeling.predictions import QAPred, QACandidate
 | 
			
		||||
from farm.modeling.adaptive_model import BaseAdaptiveModel
 | 
			
		||||
from farm.modeling.adaptive_model import BaseAdaptiveModel, AdaptiveModel
 | 
			
		||||
from farm.train import Trainer
 | 
			
		||||
from farm.eval import Evaluator
 | 
			
		||||
from farm.utils import set_all_seeds, initialize_device_settings
 | 
			
		||||
@ -567,7 +567,15 @@ class FARMReader(BaseReader):
 | 
			
		||||
        return predictions
 | 
			
		||||
 | 
			
		||||
    @classmethod
    def convert_to_onnx(
            cls,
            model_name: str,
            output_path: Path,
            convert_to_float16: bool = False,
            quantize: bool = False,
            task_type: str = "question_answering",
            opset_version: int = 11
    ):
        """
        Convert a PyTorch BERT model to ONNX format and write it to the supplied output directory.
        The converted ONNX model can be loaded in the `FARMReader` by passing the export path as the
        `model_name_or_path` param.

        Usage:

            `from pathlib import Path
            onnx_model_path = Path("roberta-onnx-model")
            FARMReader.convert_to_onnx(model_name="deepset/bert-base-cased-squad2", output_path=onnx_model_path)
            reader = FARMReader(onnx_model_path)`

        :param model_name: transformers model name
        :param output_path: Path to output the converted model
        :param convert_to_float16: Many models use float32 precision by default. With the half precision of float16,
                                   inference is faster on Nvidia GPUs with Tensor core like T4 or V100. On older GPUs,
                                   float32 could still be more performant.
        :param quantize: convert floating point number to integers
        :param task_type: Type of task for the model. Available options: "question_answering" or "embeddings".
        :param opset_version: ONNX opset version
        """
        # Delegate the actual export to FARM; all parameters are forwarded unchanged.
        AdaptiveModel.convert_to_onnx(
            model_name=model_name,
            output_path=output_path,
            task_type=task_type,
            convert_to_float16=convert_to_float16,
            quantize=quantize,
            opset_version=opset_version
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user