mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 18:59:28 +00:00
Fixed bug in onnx converter for XLMRoberta architecture (#3470)
This commit is contained in:
parent
9f4a9a76a3
commit
384663981d
@ -632,7 +632,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
:return: None.
|
||||
"""
|
||||
model_type = capitalize_model_type(_get_model_type(model_name)) # type: ignore
|
||||
if model_type not in ["Bert", "Roberta", "XMLRoberta"]:
|
||||
if model_type not in ["Bert", "Roberta", "XLMRoberta"]:
|
||||
raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.")
|
||||
|
||||
task_type_to_pipeline_map = {"question_answering": "question-answering"}
|
||||
@ -643,7 +643,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
model=model_name,
|
||||
output=output_path / "model.onnx",
|
||||
opset=opset_version,
|
||||
use_external_format=True if model_type == "XMLRoberta" else False,
|
||||
use_external_format=True if model_type == "XLMRoberta" else False,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
|
||||
@ -280,7 +280,10 @@ When beer is distilled, the resulting liquor is a form of whisky.[12]
|
||||
assert answer.score == qa_cand.score
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["deepset/tinyroberta-squad2", "deepset/bert-medium-squad2-distilled"])
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
["deepset/tinyroberta-squad2", "deepset/bert-medium-squad2-distilled", "deepset/xlm-roberta-base-squad2-distilled"],
|
||||
)
|
||||
def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs):
|
||||
FARMReader.convert_to_onnx(model_name=model_name, output_path=Path(tmpdir, "onnx"))
|
||||
assert os.path.exists(Path(tmpdir, "onnx", "model.onnx"))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user