mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-07-25 18:00:07 +00:00
Merge pull request #6042 from andyjpaddle/dygraph
update cpp infer for rec
This commit is contained in:
commit
a9c0cab75d
@ -46,6 +46,8 @@ DECLARE_int32(cls_batch_num);
|
|||||||
DECLARE_string(rec_model_dir);
|
DECLARE_string(rec_model_dir);
|
||||||
DECLARE_int32(rec_batch_num);
|
DECLARE_int32(rec_batch_num);
|
||||||
DECLARE_string(rec_char_dict_path);
|
DECLARE_string(rec_char_dict_path);
|
||||||
|
DECLARE_int32(rec_img_h);
|
||||||
|
DECLARE_int32(rec_img_w);
|
||||||
// forward related
|
// forward related
|
||||||
DECLARE_bool(det);
|
DECLARE_bool(det);
|
||||||
DECLARE_bool(rec);
|
DECLARE_bool(rec);
|
||||||
|
@ -45,7 +45,8 @@ public:
|
|||||||
const bool &use_mkldnn, const string &label_path,
|
const bool &use_mkldnn, const string &label_path,
|
||||||
const bool &use_tensorrt,
|
const bool &use_tensorrt,
|
||||||
const std::string &precision,
|
const std::string &precision,
|
||||||
const int &rec_batch_num) {
|
const int &rec_batch_num, const int &rec_img_h,
|
||||||
|
const int &rec_img_w) {
|
||||||
this->use_gpu_ = use_gpu;
|
this->use_gpu_ = use_gpu;
|
||||||
this->gpu_id_ = gpu_id;
|
this->gpu_id_ = gpu_id;
|
||||||
this->gpu_mem_ = gpu_mem;
|
this->gpu_mem_ = gpu_mem;
|
||||||
@ -54,6 +55,10 @@ public:
|
|||||||
this->use_tensorrt_ = use_tensorrt;
|
this->use_tensorrt_ = use_tensorrt;
|
||||||
this->precision_ = precision;
|
this->precision_ = precision;
|
||||||
this->rec_batch_num_ = rec_batch_num;
|
this->rec_batch_num_ = rec_batch_num;
|
||||||
|
this->rec_img_h_ = rec_img_h;
|
||||||
|
this->rec_img_w_ = rec_img_w;
|
||||||
|
std::vector<int> rec_image_shape = {3, rec_img_h, rec_img_w};
|
||||||
|
this->rec_image_shape_ = rec_image_shape;
|
||||||
|
|
||||||
this->label_list_ = Utility::ReadDict(label_path);
|
this->label_list_ = Utility::ReadDict(label_path);
|
||||||
this->label_list_.insert(this->label_list_.begin(),
|
this->label_list_.insert(this->label_list_.begin(),
|
||||||
@ -86,7 +91,9 @@ private:
|
|||||||
bool use_tensorrt_ = false;
|
bool use_tensorrt_ = false;
|
||||||
std::string precision_ = "fp32";
|
std::string precision_ = "fp32";
|
||||||
int rec_batch_num_ = 6;
|
int rec_batch_num_ = 6;
|
||||||
|
int rec_img_h_ = 32;
|
||||||
|
int rec_img_w_ = 320;
|
||||||
|
std::vector<int> rec_image_shape_ = {3, rec_img_h_, rec_img_w_};
|
||||||
// pre-process
|
// pre-process
|
||||||
CrnnResizeImg resize_op_;
|
CrnnResizeImg resize_op_;
|
||||||
Normalize normalize_op_;
|
Normalize normalize_op_;
|
||||||
|
@ -323,6 +323,8 @@ More parameters are as follows,
|
|||||||
|rec_model_dir|string|-|Address of recognition inference model|
|
|rec_model_dir|string|-|Address of recognition inference model|
|
||||||
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file|
|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file|
|
||||||
|rec_batch_num|int|6|batch size of recognition|
|
|rec_batch_num|int|6|batch size of recognition|
|
||||||
|
|rec_img_h|int|32|image height of recognition|
|
||||||
|
|rec_img_w|int|320|image width of recognition|
|
||||||
|
|
||||||
* Multi-language inference is also supported in PaddleOCR; refer to the [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models. Specifically, to run inference with a multi-language model, you only need to modify the values of `rec_char_dict_path` and `rec_model_dir`.
|
* Multi-language inference is also supported in PaddleOCR; refer to the [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models. Specifically, to run inference with a multi-language model, you only need to modify the values of `rec_char_dict_path` and `rec_model_dir`.
|
||||||
|
|
||||||
|
@ -336,6 +336,8 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|
|||||||
|rec_model_dir|string|-|识别模型inference model地址|
|
|rec_model_dir|string|-|识别模型inference model地址|
|
||||||
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件|
|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件|
|
||||||
|rec_batch_num|int|6|识别模型batchsize|
|
|rec_batch_num|int|6|识别模型batchsize|
|
||||||
|
|rec_img_h|int|32|识别模型输入图像高度|
|
||||||
|
|rec_img_w|int|320|识别模型输入图像宽度|
|
||||||
|
|
||||||
|
|
||||||
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
|
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
|
||||||
|
@ -47,6 +47,8 @@ DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
|
|||||||
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
|
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
|
||||||
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
|
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
|
||||||
"Path of dictionary.");
|
"Path of dictionary.");
|
||||||
|
DEFINE_int32(rec_img_h, 32, "rec image height");
|
||||||
|
DEFINE_int32(rec_img_w, 320, "rec image width");
|
||||||
|
|
||||||
// ocr forward related
|
// ocr forward related
|
||||||
DEFINE_bool(det, true, "Whether use det in forward.");
|
DEFINE_bool(det, true, "Whether use det in forward.");
|
||||||
|
@ -39,7 +39,9 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
|
|||||||
auto preprocess_start = std::chrono::steady_clock::now();
|
auto preprocess_start = std::chrono::steady_clock::now();
|
||||||
int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_);
|
int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_);
|
||||||
int batch_num = end_img_no - beg_img_no;
|
int batch_num = end_img_no - beg_img_no;
|
||||||
float max_wh_ratio = 0;
|
int imgH = this->rec_image_shape_[1];
|
||||||
|
int imgW = this->rec_image_shape_[2];
|
||||||
|
float max_wh_ratio = imgW * 1.0 / imgH;
|
||||||
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
||||||
int h = img_list[indices[ino]].rows;
|
int h = img_list[indices[ino]].rows;
|
||||||
int w = img_list[indices[ino]].cols;
|
int w = img_list[indices[ino]].cols;
|
||||||
@ -47,28 +49,28 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
|
|||||||
max_wh_ratio = max(max_wh_ratio, wh_ratio);
|
max_wh_ratio = max(max_wh_ratio, wh_ratio);
|
||||||
}
|
}
|
||||||
|
|
||||||
int batch_width = 0;
|
int batch_width = imgW;
|
||||||
std::vector<cv::Mat> norm_img_batch;
|
std::vector<cv::Mat> norm_img_batch;
|
||||||
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
||||||
cv::Mat srcimg;
|
cv::Mat srcimg;
|
||||||
img_list[indices[ino]].copyTo(srcimg);
|
img_list[indices[ino]].copyTo(srcimg);
|
||||||
cv::Mat resize_img;
|
cv::Mat resize_img;
|
||||||
this->resize_op_.Run(srcimg, resize_img, max_wh_ratio,
|
this->resize_op_.Run(srcimg, resize_img, max_wh_ratio,
|
||||||
this->use_tensorrt_);
|
this->use_tensorrt_, this->rec_image_shape_);
|
||||||
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
||||||
this->is_scale_);
|
this->is_scale_);
|
||||||
norm_img_batch.push_back(resize_img);
|
norm_img_batch.push_back(resize_img);
|
||||||
batch_width = max(resize_img.cols, batch_width);
|
batch_width = max(resize_img.cols, batch_width);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> input(batch_num * 3 * 32 * batch_width, 0.0f);
|
std::vector<float> input(batch_num * 3 * imgH * batch_width, 0.0f);
|
||||||
this->permute_op_.Run(norm_img_batch, input.data());
|
this->permute_op_.Run(norm_img_batch, input.data());
|
||||||
auto preprocess_end = std::chrono::steady_clock::now();
|
auto preprocess_end = std::chrono::steady_clock::now();
|
||||||
preprocess_diff += preprocess_end - preprocess_start;
|
preprocess_diff += preprocess_end - preprocess_start;
|
||||||
// Inference.
|
// Inference.
|
||||||
auto input_names = this->predictor_->GetInputNames();
|
auto input_names = this->predictor_->GetInputNames();
|
||||||
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
|
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
|
||||||
input_t->Reshape({batch_num, 3, 32, batch_width});
|
input_t->Reshape({batch_num, 3, imgH, batch_width});
|
||||||
auto inference_start = std::chrono::steady_clock::now();
|
auto inference_start = std::chrono::steady_clock::now();
|
||||||
input_t->CopyFromCpu(input.data());
|
input_t->CopyFromCpu(input.data());
|
||||||
this->predictor_->Run();
|
this->predictor_->Run();
|
||||||
@ -142,13 +144,14 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
|||||||
precision = paddle_infer::Config::Precision::kInt8;
|
precision = paddle_infer::Config::Precision::kInt8;
|
||||||
}
|
}
|
||||||
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
|
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
|
||||||
|
int imgH = this->rec_image_shape_[1];
|
||||||
|
int imgW = this->rec_image_shape_[2];
|
||||||
std::map<std::string, std::vector<int>> min_input_shape = {
|
std::map<std::string, std::vector<int>> min_input_shape = {
|
||||||
{"x", {1, 3, 32, 10}}, {"lstm_0.tmp_0", {10, 1, 96}}};
|
{"x", {1, 3, imgH, 10}}, {"lstm_0.tmp_0", {10, 1, 96}}};
|
||||||
std::map<std::string, std::vector<int>> max_input_shape = {
|
std::map<std::string, std::vector<int>> max_input_shape = {
|
||||||
{"x", {1, 3, 32, 2000}}, {"lstm_0.tmp_0", {1000, 1, 96}}};
|
{"x", {1, 3, imgH, 2000}}, {"lstm_0.tmp_0", {1000, 1, 96}}};
|
||||||
std::map<std::string, std::vector<int>> opt_input_shape = {
|
std::map<std::string, std::vector<int>> opt_input_shape = {
|
||||||
{"x", {1, 3, 32, 320}}, {"lstm_0.tmp_0", {25, 1, 96}}};
|
{"x", {1, 3, imgH, imgW}}, {"lstm_0.tmp_0", {25, 1, 96}}};
|
||||||
|
|
||||||
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
|
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
|
||||||
opt_input_shape);
|
opt_input_shape);
|
||||||
|
@ -39,7 +39,8 @@ PPOCR::PPOCR() {
|
|||||||
this->recognizer_ = new CRNNRecognizer(
|
this->recognizer_ = new CRNNRecognizer(
|
||||||
FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
|
FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
|
||||||
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_rec_char_dict_path,
|
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_rec_char_dict_path,
|
||||||
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num);
|
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num,
|
||||||
|
FLAGS_rec_img_h, FLAGS_rec_img_w);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -41,16 +41,17 @@ void Permute::Run(const cv::Mat *im, float *data) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PermuteBatch::Run(const std::vector<cv::Mat> imgs, float *data) {
|
void PermuteBatch::Run(const std::vector<cv::Mat> imgs, float *data) {
|
||||||
for (int j = 0; j < imgs.size(); j ++){
|
for (int j = 0; j < imgs.size(); j++) {
|
||||||
int rh = imgs[j].rows;
|
int rh = imgs[j].rows;
|
||||||
int rw = imgs[j].cols;
|
int rw = imgs[j].cols;
|
||||||
int rc = imgs[j].channels();
|
int rc = imgs[j].channels();
|
||||||
for (int i = 0; i < rc; ++i) {
|
for (int i = 0; i < rc; ++i) {
|
||||||
cv::extractChannel(imgs[j], cv::Mat(rh, rw, CV_32FC1, data + (j * rc + i) * rh * rw), i);
|
cv::extractChannel(
|
||||||
}
|
imgs[j], cv::Mat(rh, rw, CV_32FC1, data + (j * rc + i) * rh * rw), i);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
||||||
const std::vector<float> &scale, const bool is_scale) {
|
const std::vector<float> &scale, const bool is_scale) {
|
||||||
double e = 1.0;
|
double e = 1.0;
|
||||||
@ -101,8 +102,8 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
|
|||||||
imgC = rec_image_shape[0];
|
imgC = rec_image_shape[0];
|
||||||
imgH = rec_image_shape[1];
|
imgH = rec_image_shape[1];
|
||||||
imgW = rec_image_shape[2];
|
imgW = rec_image_shape[2];
|
||||||
|
|
||||||
imgW = int(32 * wh_ratio);
|
imgW = int(imgH * wh_ratio);
|
||||||
|
|
||||||
float ratio = float(img.cols) / float(img.rows);
|
float ratio = float(img.cols) / float(img.rows);
|
||||||
int resize_w, resize_h;
|
int resize_w, resize_h;
|
||||||
@ -111,7 +112,7 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
|
|||||||
resize_w = imgW;
|
resize_w = imgW;
|
||||||
else
|
else
|
||||||
resize_w = int(ceilf(imgH * ratio));
|
resize_w = int(ceilf(imgH * ratio));
|
||||||
|
|
||||||
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
||||||
cv::INTER_LINEAR);
|
cv::INTER_LINEAR);
|
||||||
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
|
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user