fix: ONNX FARMReader model conversion is broken (#3211)

This commit is contained in:
Vladimir Blagojevic 2022-09-26 15:18:12 +02:00 committed by GitHub
parent b579b9d54a
commit 9582a423a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 13 deletions

View File

@ -13,7 +13,12 @@ from transformers import AutoConfig, AutoModelForQuestionAnswering
from transformers.convert_graph_to_onnx import convert, quantize as quantize_model
from haystack.modeling.data_handler.processor import Processor
from haystack.modeling.model.language_model import get_language_model, LanguageModel
from haystack.modeling.model.language_model import (
get_language_model,
LanguageModel,
_get_model_type,
capitalize_model_type,
)
from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead
from haystack.utils.experiment_tracking import Tracker as tracker
@ -626,8 +631,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:return: None.
"""
language_model_class = LanguageModel.get_language_model_class(model_name)
if language_model_class not in ["Bert", "Roberta", "XLMRoberta"]:
model_type = capitalize_model_type(_get_model_type(model_name)) # type: ignore
if model_type not in ["Bert", "Roberta", "XLMRoberta"]:
raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.")
task_type_to_pipeline_map = {"question_answering": "question-answering"}
@ -638,7 +643,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
model=model_name,
output=output_path / "model.onnx",
opset=opset_version,
use_external_format=True if language_model_class == "XLMRoberta" else False,
use_external_format=True if model_type == "XLMRoberta" else False,
use_auth_token=use_auth_token,
)
@ -661,7 +666,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
onnx_model_config = {
"task_type": task_type,
"onnx_opset_version": opset_version,
"language_model_class": language_model_class,
"language_model_class": model_type,
"language": model.language_model.language,
}
with open(output_path / "onnx_model_config.json", "w") as f:

View File

@ -1,4 +1,6 @@
import math
import os
from pathlib import Path
import pytest
from haystack.modeling.data_handler.inputs import QAInput, Question
@ -277,3 +279,16 @@ When beer is distilled, the resulting liquor is a form of whisky.[12]
assert answer.score == qa_cand.confidence
else:
assert answer.score == qa_cand.score
@pytest.mark.parametrize("model_name", ["deepset/roberta-base-squad2", "deepset/bert-base-uncased-squad2"])
def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs):
    """Export a QA model to ONNX, verify the written artefacts, then answer a query from the export."""
    onnx_dir = Path(tmpdir, "onnx")
    FARMReader.convert_to_onnx(model_name=model_name, output_path=onnx_dir)

    # Files the conversion is expected to leave in the output directory.
    expected_files = (
        "model.onnx",
        "processor_config.json",
        "onnx_model_config.json",
        "language_model_config.json",
    )
    for file_name in expected_files:
        assert os.path.exists(Path(onnx_dir, file_name))

    # Load the reader back from the exported directory and run a prediction with it.
    onnx_reader = FARMReader(str(onnx_dir))
    prediction = onnx_reader.predict(query="Where does Paul live?", documents=[docs[0]])
    assert prediction["answers"][0].answer == "New York"

View File

@ -1,5 +1,4 @@
import logging
import time
from math import isclose
import numpy as np
@ -7,7 +6,6 @@ import pandas as pd
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.memory import InMemoryDocumentStore
import pytest
from pathlib import Path
from elasticsearch import Elasticsearch
from haystack.document_stores import WeaviateDocumentStore
@ -16,12 +14,7 @@ from haystack.schema import Document
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.document_stores import MilvusDocumentStore
from haystack.nodes.retriever.dense import (
DensePassageRetriever,
EmbeddingRetriever,
TableTextRetriever,
MultihopEmbeddingRetriever,
)
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast