mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-30 17:29:29 +00:00
fix: ONNX FARMReader model conversion is broken (#3211)
This commit is contained in:
parent
b579b9d54a
commit
9582a423a2
@ -13,7 +13,12 @@ from transformers import AutoConfig, AutoModelForQuestionAnswering
|
||||
from transformers.convert_graph_to_onnx import convert, quantize as quantize_model
|
||||
|
||||
from haystack.modeling.data_handler.processor import Processor
|
||||
from haystack.modeling.model.language_model import get_language_model, LanguageModel
|
||||
from haystack.modeling.model.language_model import (
|
||||
get_language_model,
|
||||
LanguageModel,
|
||||
_get_model_type,
|
||||
capitalize_model_type,
|
||||
)
|
||||
from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
|
||||
@ -626,8 +631,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
|
||||
:return: None.
|
||||
"""
|
||||
language_model_class = LanguageModel.get_language_model_class(model_name)
|
||||
if language_model_class not in ["Bert", "Roberta", "XLMRoberta"]:
|
||||
model_type = capitalize_model_type(_get_model_type(model_name)) # type: ignore
|
||||
if model_type not in ["Bert", "Roberta", "XLMRoberta"]:
|
||||
raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.")
|
||||
|
||||
task_type_to_pipeline_map = {"question_answering": "question-answering"}
|
||||
@ -638,7 +643,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
model=model_name,
|
||||
output=output_path / "model.onnx",
|
||||
opset=opset_version,
|
||||
use_external_format=True if language_model_class == "XLMRoberta" else False,
|
||||
use_external_format=True if model_type == "XLMRoberta" else False,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
@ -661,7 +666,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
onnx_model_config = {
|
||||
"task_type": task_type,
|
||||
"onnx_opset_version": opset_version,
|
||||
"language_model_class": language_model_class,
|
||||
"language_model_class": model_type,
|
||||
"language": model.language_model.language,
|
||||
}
|
||||
with open(output_path / "onnx_model_config.json", "w") as f:
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from haystack.modeling.data_handler.inputs import QAInput, Question
|
||||
@ -277,3 +279,16 @@ When beer is distilled, the resulting liquor is a form of whisky.[12]
|
||||
assert answer.score == qa_cand.confidence
|
||||
else:
|
||||
assert answer.score == qa_cand.score
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["deepset/roberta-base-squad2", "deepset/bert-base-uncased-squad2"])
def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs):
    """Convert a QA model to ONNX, check the exported files, then predict from the export.

    The test is parametrized over a RoBERTa and a BERT checkpoint. After the
    conversion it asserts that the four expected files exist in the output
    directory, loads a ``FARMReader`` from that directory, and checks the top
    answer for a single query against the first document of the ``docs``
    fixture.
    """
    onnx_dir = Path(tmpdir, "onnx")
    FARMReader.convert_to_onnx(model_name=model_name, output_path=onnx_dir)

    # Files the conversion is expected to leave in the output directory.
    expected_files = (
        "model.onnx",
        "processor_config.json",
        "onnx_model_config.json",
        "language_model_config.json",
    )
    for file_name in expected_files:
        assert os.path.exists(onnx_dir / file_name)

    # Load the reader straight from the exported directory and run inference.
    reader = FARMReader(str(onnx_dir))
    prediction = reader.predict(query="Where does Paul live?", documents=[docs[0]])
    top_answer = prediction["answers"][0]
    assert top_answer.answer == "New York"
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
from math import isclose
|
||||
|
||||
import numpy as np
|
||||
@ -7,7 +6,6 @@ import pandas as pd
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
from haystack.document_stores.memory import InMemoryDocumentStore
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from haystack.document_stores import WeaviateDocumentStore
|
||||
@ -16,12 +14,7 @@ from haystack.schema import Document
|
||||
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.document_stores.faiss import FAISSDocumentStore
|
||||
from haystack.document_stores import MilvusDocumentStore
|
||||
from haystack.nodes.retriever.dense import (
|
||||
DensePassageRetriever,
|
||||
EmbeddingRetriever,
|
||||
TableTextRetriever,
|
||||
MultihopEmbeddingRetriever,
|
||||
)
|
||||
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
|
||||
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
|
||||
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user