diff --git a/paddleocr/_cli.py b/paddleocr/_cli.py index 5d5b392f1d..3c644bfdd3 100644 --- a/paddleocr/_cli.py +++ b/paddleocr/_cli.py @@ -13,6 +13,8 @@ # limitations under the License. import argparse +import subprocess +import sys import warnings from ._models import ( @@ -79,12 +81,27 @@ def _register_models(subparsers): subparser.set_defaults(executor=subcommand_executor.execute_with_args) +def _register_install_hpi_deps_command(subparsers): + def _install_hpi_deps(args): + hpip = f"hpi-{args.variant}" + try: + subprocess.check_call(["paddlex", "--install", hpip]) + subprocess.check_call(["paddlex", "--install", "paddle2onnx"]) + except subprocess.CalledProcessError: + sys.exit("Failed to install dependencies") + + subparser = subparsers.add_parser("install_hpi_deps") + subparser.add_argument("variant", type=str, choices=["cpu", "gpu", "npu"]) + subparser.set_defaults(executor=_install_hpi_deps) + + def _parse_args(): parser = argparse.ArgumentParser(prog="paddleocr") parser.add_argument("--version", action="version", version=f"%(prog)s {version}") subparsers = parser.add_subparsers(dest="subcommand") _register_pipelines(subparsers) _register_models(subparsers) + _register_install_hpi_deps_command(subparsers) return parser.parse_args()