mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-26 21:24:27 +00:00
add version control for export and modify hpi config (#14513)
This commit is contained in:
parent
a6b96bbfb1
commit
bf2b73f0f0
@ -53,7 +53,7 @@ def dump_infer_config(config, path, logger):
|
||||
}
|
||||
elif arch_config["model_type"] == "det":
|
||||
common_dynamic_shapes = {
|
||||
"x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
|
||||
"x": [[1, 3, 160, 160], [1, 3, 640, 640], [1, 3, 1280, 1280]]
|
||||
}
|
||||
elif arch_config["algorithm"] == "SLANet":
|
||||
common_dynamic_shapes = {
|
||||
@ -64,11 +64,17 @@ def dump_infer_config(config, path, logger):
|
||||
"x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
|
||||
}
|
||||
elif arch_config["algorithm"] == "UniMERNet":
|
||||
common_dynamic_shapes = {"x": [[1, 3, 192, 672]]}
|
||||
common_dynamic_shapes = {
|
||||
"x": [[1, 3, 192, 672], [1, 3, 192, 672], [8, 3, 192, 672]]
|
||||
}
|
||||
elif arch_config["algorithm"] == "PP-FormulaNet-L":
|
||||
common_dynamic_shapes = {"x": [[1, 3, 768, 768]]}
|
||||
common_dynamic_shapes = {
|
||||
"x": [[1, 3, 768, 768], [1, 3, 768, 768], [8, 3, 768, 768]]
|
||||
}
|
||||
elif arch_config["algorithm"] == "PP-FormulaNet-S":
|
||||
common_dynamic_shapes = {"x": [[1, 3, 384, 384]]}
|
||||
common_dynamic_shapes = {
|
||||
"x": [[1, 3, 384, 384], [1, 3, 384, 384], [8, 3, 384, 384]]
|
||||
}
|
||||
else:
|
||||
common_dynamic_shapes = None
|
||||
|
||||
@ -345,17 +351,22 @@ def export_single_model(
|
||||
ModuleNotFoundError
|
||||
): # Encryption is not needed if the module cannot be imported
|
||||
print("Skipping import of the encryption module")
|
||||
paddle_version = version.parse(paddle.__version__)
|
||||
if config["Global"].get("export_with_pir", False):
|
||||
paddle_version = version.parse(paddle.__version__)
|
||||
assert (
|
||||
paddle_version >= version.parse("3.0.0b2")
|
||||
or paddle_version == version.parse("0.0.0")
|
||||
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]
|
||||
paddle.jit.save(model, save_path)
|
||||
else:
|
||||
model.forward.rollback()
|
||||
with paddle.pir_utils.OldIrGuard():
|
||||
model = dynamic_to_static(model, arch_config, logger, input_shape)
|
||||
if paddle_version >= version.parse(
|
||||
"3.0.0b2"
|
||||
) or paddle_version == version.parse("0.0.0"):
|
||||
model.forward.rollback()
|
||||
with paddle.pir_utils.OldIrGuard():
|
||||
model = dynamic_to_static(model, arch_config, logger, input_shape)
|
||||
paddle.jit.save(model, save_path)
|
||||
else:
|
||||
paddle.jit.save(model, save_path)
|
||||
else:
|
||||
quanter.save_quantized_model(model, save_path)
|
||||
|
Loading…
x
Reference in New Issue
Block a user