mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-12 08:03:34 +00:00
onnxruntime support gpu (#10668)
* Update ch_PP-OCRv3_rec.yml * Update ch_PP-OCRv3_rec_distillation.yml * Update en_PP-OCRv3_rec.yml * Update arabic_PP-OCRv3_rec.yml * Update chinese_cht_PP-OCRv3_rec.yml * Update cyrillic_PP-OCRv3_rec.yml * Update devanagari_PP-OCRv3_rec.yml * Update japan_PP-OCRv3_rec.yml * Update ka_PP-OCRv3_rec.yml * Update korean_PP-OCRv3_rec.yml * Update latin_PP-OCRv3_rec.yml * Update ta_PP-OCRv3_rec.yml * Update te_PP-OCRv3_rec.yml * Update utility.py
This commit is contained in:
parent
8f010ecf1e
commit
dbf35bb714
@ -187,6 +187,9 @@ def create_predictor(args, mode, logger):
|
|||||||
if not os.path.exists(model_file_path):
|
if not os.path.exists(model_file_path):
|
||||||
raise ValueError("not find model file path {}".format(
|
raise ValueError("not find model file path {}".format(
|
||||||
model_file_path))
|
model_file_path))
|
||||||
|
if args.use_gpu:
|
||||||
|
sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
|
||||||
|
else:
|
||||||
sess = ort.InferenceSession(model_file_path)
|
sess = ort.InferenceSession(model_file_path)
|
||||||
return sess, sess.get_inputs()[0], None, None
|
return sess, sess.get_inputs()[0], None, None
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user