Mirror of https://github.com/deepset-ai/haystack.git
fix: ONNX FARMReader model conversion is broken (#3211)
commit 9582a423a2
parent b579b9d54a
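For context, the path this commit repairs is the FARMReader ONNX export-and-reload flow. Below is a minimal sketch of that flow, modelled on the test added in this commit; the model name, output directory, and query/document are illustrative placeholders, not part of the change itself.

from pathlib import Path

from haystack.nodes import FARMReader
from haystack.schema import Document

# Export a supported QA model (BERT, RoBERTa, or XLM-RoBERTa) to ONNX.
FARMReader.convert_to_onnx(
    model_name="deepset/roberta-base-squad2",
    output_path=Path("onnx-export"),
)

# Reload the exported model and run extractive QA on a single document.
reader = FARMReader(str(Path("onnx-export")))
result = reader.predict(
    query="Where does Paul live?",
    documents=[Document(content="My name is Paul and I live in New York.")],
)
print(result["answers"][0].answer)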
@@ -13,7 +13,12 @@ from transformers import AutoConfig, AutoModelForQuestionAnswering
 from transformers.convert_graph_to_onnx import convert, quantize as quantize_model
 from haystack.modeling.data_handler.processor import Processor
-from haystack.modeling.model.language_model import get_language_model, LanguageModel
+from haystack.modeling.model.language_model import (
+    get_language_model,
+    LanguageModel,
+    _get_model_type,
+    capitalize_model_type,
+)
 from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead
 from haystack.utils.experiment_tracking import Tracker as tracker

@@ -626,8 +631,8 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
         https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
         :return: None.
         """
-        language_model_class = LanguageModel.get_language_model_class(model_name)
-        if language_model_class not in ["Bert", "Roberta", "XLMRoberta"]:
+        model_type = capitalize_model_type(_get_model_type(model_name))  # type: ignore
+        if model_type not in ["Bert", "Roberta", "XMLRoberta"]:
             raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.")

         task_type_to_pipeline_map = {"question_answering": "question-answering"}
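The hunk above is the heart of the fix: the call to LanguageModel.get_language_model_class is apparently no longer available after the language-model refactoring, so the conversion now derives the model type via _get_model_type and capitalize_model_type. A rough sketch of what that resolution does for a Hub model name follows; the printed value is an assumption, not output recorded in the commit.

from haystack.modeling.model.language_model import _get_model_type, capitalize_model_type

# _get_model_type inspects the model's configuration and returns its type as a
# lower-case string; capitalize_model_type turns it back into the spelling the
# whitelist above expects.
model_type = capitalize_model_type(_get_model_type("deepset/roberta-base-squad2"))
print(model_type)  # expected: "Roberta" for this model (assumption)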
@@ -638,7 +643,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
             model=model_name,
             output=output_path / "model.onnx",
             opset=opset_version,
-            use_external_format=True if language_model_class == "XLMRoberta" else False,
+            use_external_format=True if model_type == "XMLRoberta" else False,
             use_auth_token=use_auth_token,
         )

@@ -661,7 +666,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
         onnx_model_config = {
             "task_type": task_type,
             "onnx_opset_version": opset_version,
-            "language_model_class": language_model_class,
+            "language_model_class": model_type,
             "language": model.language_model.language,
         }
         with open(output_path / "onnx_model_config.json", "w") as f:
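Given the keys written above, the onnx_model_config.json emitted next to model.onnx would look roughly like the dict below for a RoBERTa QA export. The concrete values (opset version, language) are assumptions for illustration, not taken from the commit.

# Illustrative shape of onnx_model_config.json after converting a RoBERTa QA model.
onnx_model_config = {
    "task_type": "question_answering",
    "onnx_opset_version": 11,
    "language_model_class": "Roberta",
    "language": "english",
}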
@@ -1,4 +1,6 @@
 import math
+import os
+from pathlib import Path

 import pytest
 from haystack.modeling.data_handler.inputs import QAInput, Question
@@ -277,3 +279,16 @@ When beer is distilled, the resulting liquor is a form of whisky.[12]
             assert answer.score == qa_cand.confidence
         else:
             assert answer.score == qa_cand.score
+
+
+@pytest.mark.parametrize("model_name", ["deepset/roberta-base-squad2", "deepset/bert-base-uncased-squad2"])
+def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs):
+    FARMReader.convert_to_onnx(model_name=model_name, output_path=Path(tmpdir, "onnx"))
+    assert os.path.exists(Path(tmpdir, "onnx", "model.onnx"))
+    assert os.path.exists(Path(tmpdir, "onnx", "processor_config.json"))
+    assert os.path.exists(Path(tmpdir, "onnx", "onnx_model_config.json"))
+    assert os.path.exists(Path(tmpdir, "onnx", "language_model_config.json"))
+
+    reader = FARMReader(str(Path(tmpdir, "onnx")))
+    result = reader.predict(query="Where does Paul live?", documents=[docs[0]])
+    assert result["answers"][0].answer == "New York"
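The new test downloads the two checkpoints from the Hugging Face Hub and assumes the ONNX-related dependencies are installed. One way to run just this test from a source checkout is to invoke pytest programmatically; the keyword filter is taken from the function name above, while the extra flag is only a suggestion.

import pytest

# Run only the ONNX conversion test added in this commit, stopping on first failure.
pytest.main(["-x", "-k", "test_farm_reader_onnx_conversion_and_inference"])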
@@ -1,5 +1,4 @@
 import logging
-import time
 from math import isclose

 import numpy as np
@@ -7,7 +6,6 @@ import pandas as pd
 from haystack.document_stores.base import BaseDocumentStore
 from haystack.document_stores.memory import InMemoryDocumentStore
 import pytest
-from pathlib import Path
 from elasticsearch import Elasticsearch

 from haystack.document_stores import WeaviateDocumentStore
@@ -16,12 +14,7 @@ from haystack.schema import Document
 from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
 from haystack.document_stores.faiss import FAISSDocumentStore
 from haystack.document_stores import MilvusDocumentStore
-from haystack.nodes.retriever.dense import (
-    DensePassageRetriever,
-    EmbeddingRetriever,
-    TableTextRetriever,
-    MultihopEmbeddingRetriever,
-)
+from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
 from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
 from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
