diff --git a/haystack/modeling/model/adaptive_model.py b/haystack/modeling/model/adaptive_model.py index a8db8f24e..fccef231a 100644 --- a/haystack/modeling/model/adaptive_model.py +++ b/haystack/modeling/model/adaptive_model.py @@ -13,7 +13,12 @@ from transformers import AutoConfig, AutoModelForQuestionAnswering from transformers.convert_graph_to_onnx import convert, quantize as quantize_model from haystack.modeling.data_handler.processor import Processor -from haystack.modeling.model.language_model import get_language_model, LanguageModel +from haystack.modeling.model.language_model import ( + get_language_model, + LanguageModel, + _get_model_type, + capitalize_model_type, +) from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead from haystack.utils.experiment_tracking import Tracker as tracker @@ -626,8 +631,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel): https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained :return: None. """ - language_model_class = LanguageModel.get_language_model_class(model_name) - if language_model_class not in ["Bert", "Roberta", "XLMRoberta"]: + model_type = capitalize_model_type(_get_model_type(model_name)) # type: ignore + if model_type not in ["Bert", "Roberta", "XMLRoberta"]: raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.") task_type_to_pipeline_map = {"question_answering": "question-answering"} @@ -638,7 +643,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel): model=model_name, output=output_path / "model.onnx", opset=opset_version, - use_external_format=True if language_model_class == "XLMRoberta" else False, + use_external_format=True if model_type == "XMLRoberta" else False, use_auth_token=use_auth_token, ) @@ -661,7 +666,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel): onnx_model_config = { "task_type": task_type, "onnx_opset_version": opset_version, - "language_model_class": language_model_class, + "language_model_class": model_type, "language": model.language_model.language, } with open(output_path / "onnx_model_config.json", "w") as f: diff --git a/test/nodes/test_reader.py b/test/nodes/test_reader.py index cce30def6..646a6afc9 100644 --- a/test/nodes/test_reader.py +++ b/test/nodes/test_reader.py @@ -1,4 +1,6 @@ import math +import os +from pathlib import Path import pytest from haystack.modeling.data_handler.inputs import QAInput, Question @@ -277,3 +279,16 @@ When beer is distilled, the resulting liquor is a form of whisky.[12] assert answer.score == qa_cand.confidence else: assert answer.score == qa_cand.score + + +@pytest.mark.parametrize("model_name", ["deepset/roberta-base-squad2", "deepset/bert-base-uncased-squad2"]) +def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs): + FARMReader.convert_to_onnx(model_name=model_name, output_path=Path(tmpdir, "onnx")) + assert os.path.exists(Path(tmpdir, "onnx", "model.onnx")) + assert os.path.exists(Path(tmpdir, "onnx", "processor_config.json")) + assert os.path.exists(Path(tmpdir, "onnx", "onnx_model_config.json")) + assert os.path.exists(Path(tmpdir, "onnx", "language_model_config.json")) + + reader = FARMReader(str(Path(tmpdir, "onnx"))) + result = reader.predict(query="Where does Paul live?", documents=[docs[0]]) + assert result["answers"][0].answer == "New York" diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index ffff02276..076a70dce 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -1,5 +1,4 @@ import logging -import time from math import isclose import numpy as np @@ -7,7 +6,6 @@ import pandas as pd from haystack.document_stores.base import BaseDocumentStore from haystack.document_stores.memory import InMemoryDocumentStore import pytest -from pathlib import Path from elasticsearch import Elasticsearch from haystack.document_stores import WeaviateDocumentStore @@ -16,12 +14,7 @@ from haystack.schema import Document from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore from haystack.document_stores.faiss import FAISSDocumentStore from haystack.document_stores import MilvusDocumentStore -from haystack.nodes.retriever.dense import ( - DensePassageRetriever, - EmbeddingRetriever, - TableTextRetriever, - MultihopEmbeddingRetriever, -) +from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast