mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 09:49:48 +00:00 
			
		
		
		
	fix: ONNX FARMReader model conversion is broken (#3211)
This commit is contained in:
		
							parent
							
								
									b579b9d54a
								
							
						
					
					
						commit
						9582a423a2
					
				| @ -13,7 +13,12 @@ from transformers import AutoConfig, AutoModelForQuestionAnswering | |||||||
| from transformers.convert_graph_to_onnx import convert, quantize as quantize_model | from transformers.convert_graph_to_onnx import convert, quantize as quantize_model | ||||||
| 
 | 
 | ||||||
| from haystack.modeling.data_handler.processor import Processor | from haystack.modeling.data_handler.processor import Processor | ||||||
| from haystack.modeling.model.language_model import get_language_model, LanguageModel | from haystack.modeling.model.language_model import ( | ||||||
|  |     get_language_model, | ||||||
|  |     LanguageModel, | ||||||
|  |     _get_model_type, | ||||||
|  |     capitalize_model_type, | ||||||
|  | ) | ||||||
| from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead | from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead | ||||||
| from haystack.utils.experiment_tracking import Tracker as tracker | from haystack.utils.experiment_tracking import Tracker as tracker | ||||||
| 
 | 
 | ||||||
| @ -626,8 +631,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel): | |||||||
|                                https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained |                                https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained | ||||||
|         :return: None. |         :return: None. | ||||||
|         """ |         """ | ||||||
|         language_model_class = LanguageModel.get_language_model_class(model_name) |         model_type = capitalize_model_type(_get_model_type(model_name))  # type: ignore | ||||||
|         if language_model_class not in ["Bert", "Roberta", "XLMRoberta"]: |         if model_type not in ["Bert", "Roberta", "XLMRoberta"]: | ||||||
|             raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.") |             raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.") | ||||||
| 
 | 
 | ||||||
|         task_type_to_pipeline_map = {"question_answering": "question-answering"} |         task_type_to_pipeline_map = {"question_answering": "question-answering"} | ||||||
| @ -638,7 +643,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel): | |||||||
|             model=model_name, |             model=model_name, | ||||||
|             output=output_path / "model.onnx", |             output=output_path / "model.onnx", | ||||||
|             opset=opset_version, |             opset=opset_version, | ||||||
|             use_external_format=True if language_model_class == "XLMRoberta" else False, |             use_external_format=True if model_type == "XLMRoberta" else False, | ||||||
|             use_auth_token=use_auth_token, |             use_auth_token=use_auth_token, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| @ -661,7 +666,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel): | |||||||
|         onnx_model_config = { |         onnx_model_config = { | ||||||
|             "task_type": task_type, |             "task_type": task_type, | ||||||
|             "onnx_opset_version": opset_version, |             "onnx_opset_version": opset_version, | ||||||
|             "language_model_class": language_model_class, |             "language_model_class": model_type, | ||||||
|             "language": model.language_model.language, |             "language": model.language_model.language, | ||||||
|         } |         } | ||||||
|         with open(output_path / "onnx_model_config.json", "w") as f: |         with open(output_path / "onnx_model_config.json", "w") as f: | ||||||
|  | |||||||
| @ -1,4 +1,6 @@ | |||||||
| import math | import math | ||||||
|  | import os | ||||||
|  | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| import pytest | import pytest | ||||||
| from haystack.modeling.data_handler.inputs import QAInput, Question | from haystack.modeling.data_handler.inputs import QAInput, Question | ||||||
| @ -277,3 +279,16 @@ When beer is distilled, the resulting liquor is a form of whisky.[12] | |||||||
|             assert answer.score == qa_cand.confidence |             assert answer.score == qa_cand.confidence | ||||||
|         else: |         else: | ||||||
|             assert answer.score == qa_cand.score |             assert answer.score == qa_cand.score | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize("model_name", ["deepset/roberta-base-squad2", "deepset/bert-base-uncased-squad2"]) | ||||||
|  | def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs): | ||||||
|  |     FARMReader.convert_to_onnx(model_name=model_name, output_path=Path(tmpdir, "onnx")) | ||||||
|  |     assert os.path.exists(Path(tmpdir, "onnx", "model.onnx")) | ||||||
|  |     assert os.path.exists(Path(tmpdir, "onnx", "processor_config.json")) | ||||||
|  |     assert os.path.exists(Path(tmpdir, "onnx", "onnx_model_config.json")) | ||||||
|  |     assert os.path.exists(Path(tmpdir, "onnx", "language_model_config.json")) | ||||||
|  | 
 | ||||||
|  |     reader = FARMReader(str(Path(tmpdir, "onnx"))) | ||||||
|  |     result = reader.predict(query="Where does Paul live?", documents=[docs[0]]) | ||||||
|  |     assert result["answers"][0].answer == "New York" | ||||||
|  | |||||||
| @ -1,5 +1,4 @@ | |||||||
| import logging | import logging | ||||||
| import time |  | ||||||
| from math import isclose | from math import isclose | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| @ -7,7 +6,6 @@ import pandas as pd | |||||||
| from haystack.document_stores.base import BaseDocumentStore | from haystack.document_stores.base import BaseDocumentStore | ||||||
| from haystack.document_stores.memory import InMemoryDocumentStore | from haystack.document_stores.memory import InMemoryDocumentStore | ||||||
| import pytest | import pytest | ||||||
| from pathlib import Path |  | ||||||
| from elasticsearch import Elasticsearch | from elasticsearch import Elasticsearch | ||||||
| 
 | 
 | ||||||
| from haystack.document_stores import WeaviateDocumentStore | from haystack.document_stores import WeaviateDocumentStore | ||||||
| @ -16,12 +14,7 @@ from haystack.schema import Document | |||||||
| from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore | from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore | ||||||
| from haystack.document_stores.faiss import FAISSDocumentStore | from haystack.document_stores.faiss import FAISSDocumentStore | ||||||
| from haystack.document_stores import MilvusDocumentStore | from haystack.document_stores import MilvusDocumentStore | ||||||
| from haystack.nodes.retriever.dense import ( | from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever | ||||||
|     DensePassageRetriever, |  | ||||||
|     EmbeddingRetriever, |  | ||||||
|     TableTextRetriever, |  | ||||||
|     MultihopEmbeddingRetriever, |  | ||||||
| ) |  | ||||||
| from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever | from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever | ||||||
| from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast | from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Vladimir Blagojevic
						Vladimir Blagojevic