From dbf35bb71472b189405ee061a23bb9b9327d9f62 Mon Sep 17 00:00:00 2001 From: zhoujun Date: Thu, 17 Aug 2023 18:26:20 +0800 Subject: [PATCH] onnxruntime support gpu (#10668) * Update ch_PP-OCRv3_rec.yml * Update ch_PP-OCRv3_rec_distillation.yml * Update en_PP-OCRv3_rec.yml * Update arabic_PP-OCRv3_rec.yml * Update chinese_cht_PP-OCRv3_rec.yml * Update cyrillic_PP-OCRv3_rec.yml * Update devanagari_PP-OCRv3_rec.yml * Update japan_PP-OCRv3_rec.yml * Update ka_PP-OCRv3_rec.yml * Update korean_PP-OCRv3_rec.yml * Update latin_PP-OCRv3_rec.yml * Update ta_PP-OCRv3_rec.yml * Update te_PP-OCRv3_rec.yml * Update utility.py --- tools/infer/utility.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 4b58cb4ef3..fcd8ba7f4d 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -187,7 +187,10 @@ def create_predictor(args, mode, logger): if not os.path.exists(model_file_path): raise ValueError("not find model file path {}".format( model_file_path)) - sess = ort.InferenceSession(model_file_path) + if args.use_gpu: + sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider']) + else: + sess = ort.InferenceSession(model_file_path) return sess, sess.get_inputs()[0], None, None else: