add version control for export and modify hpi config (#14513)

This commit is contained in:
zhangyubo0722 2025-01-08 17:29:52 +08:00 committed by GitHub
parent a6b96bbfb1
commit bf2b73f0f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -53,7 +53,7 @@ def dump_infer_config(config, path, logger):
}
elif arch_config["model_type"] == "det":
common_dynamic_shapes = {
"x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
"x": [[1, 3, 160, 160], [1, 3, 640, 640], [1, 3, 1280, 1280]]
}
elif arch_config["algorithm"] == "SLANet":
common_dynamic_shapes = {
@ -64,11 +64,17 @@ def dump_infer_config(config, path, logger):
"x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
}
elif arch_config["algorithm"] == "UniMERNet":
common_dynamic_shapes = {"x": [[1, 3, 192, 672]]}
common_dynamic_shapes = {
"x": [[1, 3, 192, 672], [1, 3, 192, 672], [8, 3, 192, 672]]
}
elif arch_config["algorithm"] == "PP-FormulaNet-L":
common_dynamic_shapes = {"x": [[1, 3, 768, 768]]}
common_dynamic_shapes = {
"x": [[1, 3, 768, 768], [1, 3, 768, 768], [8, 3, 768, 768]]
}
elif arch_config["algorithm"] == "PP-FormulaNet-S":
common_dynamic_shapes = {"x": [[1, 3, 384, 384]]}
common_dynamic_shapes = {
"x": [[1, 3, 384, 384], [1, 3, 384, 384], [8, 3, 384, 384]]
}
else:
common_dynamic_shapes = None
@ -345,17 +351,22 @@ def export_single_model(
ModuleNotFoundError
): # Encryption is not needed if the module cannot be imported
print("Skipping import of the encryption module")
paddle_version = version.parse(paddle.__version__)
if config["Global"].get("export_with_pir", False):
paddle_version = version.parse(paddle.__version__)
assert (
paddle_version >= version.parse("3.0.0b2")
or paddle_version == version.parse("0.0.0")
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]
paddle.jit.save(model, save_path)
else:
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = dynamic_to_static(model, arch_config, logger, input_shape)
if paddle_version >= version.parse(
"3.0.0b2"
) or paddle_version == version.parse("0.0.0"):
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = dynamic_to_static(model, arch_config, logger, input_shape)
paddle.jit.save(model, save_path)
else:
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)