Fixed bug in onnx converter for XLMRoberta architecture (#3470)

This commit is contained in:
Sebastian 2022-10-28 15:35:53 +02:00 committed by GitHub
parent 9f4a9a76a3
commit 384663981d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -632,7 +632,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
:return: None.
"""
model_type = capitalize_model_type(_get_model_type(model_name)) # type: ignore
if model_type not in ["Bert", "Roberta", "XMLRoberta"]:
if model_type not in ["Bert", "Roberta", "XLMRoberta"]:
raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.")
task_type_to_pipeline_map = {"question_answering": "question-answering"}
@ -643,7 +643,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
model=model_name,
output=output_path / "model.onnx",
opset=opset_version,
use_external_format=True if model_type == "XMLRoberta" else False,
use_external_format=True if model_type == "XLMRoberta" else False,
use_auth_token=use_auth_token,
)

View File

@ -280,7 +280,10 @@ When beer is distilled, the resulting liquor is a form of whisky.[12]
assert answer.score == qa_cand.score
@pytest.mark.parametrize("model_name", ["deepset/tinyroberta-squad2", "deepset/bert-medium-squad2-distilled"])
@pytest.mark.parametrize(
"model_name",
["deepset/tinyroberta-squad2", "deepset/bert-medium-squad2-distilled", "deepset/xlm-roberta-base-squad2-distilled"],
)
def test_farm_reader_onnx_conversion_and_inference(model_name, tmpdir, docs):
FARMReader.convert_to_onnx(model_name=model_name, output_path=Path(tmpdir, "onnx"))
assert os.path.exists(Path(tmpdir, "onnx", "model.onnx"))