CPP: code review and some minor fixes (compiler warnings...) (#14635)

* CPP: minor fixes (compiler warnings...)

* Code review
This commit is contained in:
nonwill 2025-02-08 11:09:24 +08:00 committed by GitHub
parent 2c0c4beb06
commit 02f106d668
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 124 additions and 112 deletions

View File

@ -61,23 +61,29 @@ public:
GetRotateCropImage(const cv::Mat &srcimage, GetRotateCropImage(const cv::Mat &srcimage,
const std::vector<std::vector<int>> &box) noexcept; const std::vector<std::vector<int>> &box) noexcept;
static std::vector<int> argsort(const std::vector<float> &array) noexcept; static std::vector<size_t> argsort(const std::vector<float> &array) noexcept;
static std::string basename(const std::string &filename) noexcept; static std::string basename(const std::string &filename) noexcept;
static bool PathExists(const std::string &path) noexcept; static bool PathExists(const char *path) noexcept;
/// Convenience overload for callers that hold a std::string.
/// Hands the string's null-terminated character buffer to the
/// `const char *` overload of PathExists and returns that result unchanged.
static inline bool PathExists(const std::string &path) noexcept {
  const char *const raw_path = path.c_str();
  return PathExists(raw_path);
}
static void CreateDir(const std::string &path) noexcept; static void CreateDir(const char *path) noexcept;
/// Convenience overload for callers that hold a std::string.
/// Hands the string's null-terminated character buffer to the
/// `const char *` overload of CreateDir; nothing is returned.
static inline void CreateDir(const std::string &path) noexcept {
  const char *const raw_path = path.c_str();
  CreateDir(raw_path);
}
static void static void
print_result(const std::vector<OCRPredictResult> &ocr_result) noexcept; print_result(const std::vector<OCRPredictResult> &ocr_result) noexcept;
static cv::Mat crop_image(cv::Mat &img, static cv::Mat crop_image(const cv::Mat &img,
const std::vector<int> &area) noexcept; const std::vector<int> &area) noexcept;
static cv::Mat crop_image(cv::Mat &img, static cv::Mat crop_image(const cv::Mat &img,
const std::vector<float> &area) noexcept; const std::vector<float> &area) noexcept;
static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result) noexcept; static void sort_boxes(std::vector<OCRPredictResult> &ocr_result) noexcept;
static std::vector<int> static std::vector<int>
xyxyxyxy2xyxy(const std::vector<std::vector<int>> &box) noexcept; xyxyxyxy2xyxy(const std::vector<std::vector<int>> &box) noexcept;
@ -85,9 +91,11 @@ public:
static float fast_exp(float x) noexcept; static float fast_exp(float x) noexcept;
static std::vector<float> static std::vector<float>
activation_function_softmax(std::vector<float> &src) noexcept; activation_function_softmax(const std::vector<float> &src) noexcept;
static float iou(std::vector<int> &box1, std::vector<int> &box2) noexcept; static float iou(const std::vector<int> &box1,
static float iou(std::vector<float> &box1, std::vector<float> &box2) noexcept; const std::vector<int> &box2) noexcept;
static float iou(const std::vector<float> &box1,
const std::vector<float> &box2) noexcept;
private: private:
static bool comparison_box(const OCRPredictResult &result1, static bool comparison_box(const OCRPredictResult &result1,

View File

@ -2812,7 +2812,7 @@ bool Clipper::FixupIntersectionOrder() noexcept {
if (!EdgesAdjacent(*m_IntersectList[i])) { if (!EdgesAdjacent(*m_IntersectList[i])) {
size_t j = i + 1; size_t j = i + 1;
while (j < cnt && !EdgesAdjacent(*m_IntersectList[j])) while (j < cnt && !EdgesAdjacent(*m_IntersectList[j]))
j++; ++j;
if (j == cnt) if (j == cnt)
return false; return false;
std::swap(m_IntersectList[i], m_IntersectList[j]); std::swap(m_IntersectList[i], m_IntersectList[j]);
@ -3078,7 +3078,7 @@ void Clipper::BuildResult2(PolyTree &polytree) noexcept {
pn->Index = 0; pn->Index = 0;
pn->Contour.reserve(cnt); pn->Contour.reserve(cnt);
OutPt *op = outRec->Pts->Prev; OutPt *op = outRec->Pts->Prev;
for (int j = 0; j < cnt; j++) { for (int j = 0; j < cnt; ++j) {
pn->Contour.emplace_back(op->Pt); pn->Contour.emplace_back(op->Pt);
op = op->Prev; op = op->Prev;
} }
@ -3643,7 +3643,7 @@ void ClipperOffset::AddPath(const Path &path, JoinType joinType,
int j = 0, k = 0; int j = 0, k = 0;
for (int i = 1; i <= highI; ++i) for (int i = 1; i <= highI; ++i)
if (newNode->Contour[j] != path[i]) { if (newNode->Contour[j] != path[i]) {
j++; ++j;
newNode->Contour.emplace_back(path[i]); newNode->Contour.emplace_back(path[i]);
if (path[i].Y > newNode->Contour[k].Y || if (path[i].Y > newNode->Contour[k].Y ||
(path[i].Y == newNode->Contour[k].Y && (path[i].Y == newNode->Contour[k].Y &&
@ -3826,7 +3826,7 @@ void ClipperOffset::DoOffset(double delta) noexcept {
if (len == 1) { if (len == 1) {
if (node.m_jointype == jtRound) { if (node.m_jointype == jtRound) {
double X = 1.0, Y = 0.0; double X = 1.0, Y = 0.0;
for (cInt j = 1; j <= steps; j++) { for (cInt j = 1; j <= steps; ++j) {
m_destPoly.emplace_back(Round(m_srcPoly[0].X + X * delta), m_destPoly.emplace_back(Round(m_srcPoly[0].X + X * delta),
Round(m_srcPoly[0].Y + Y * delta)); Round(m_srcPoly[0].Y + Y * delta));
double X2 = X; double X2 = X;

View File

@ -137,7 +137,7 @@ void structure(std::vector<cv::String> &cv_all_img_names) {
std::vector<StructurePredictResult> structure_results = engine.structure( std::vector<StructurePredictResult> structure_results = engine.structure(
img, FLAGS_layout, FLAGS_table, FLAGS_det && FLAGS_rec); img, FLAGS_layout, FLAGS_table, FLAGS_det && FLAGS_rec);
for (int j = 0; j < structure_results.size(); j++) { for (size_t j = 0; j < structure_results.size(); ++j) {
std::cout << j << "\ttype: " << structure_results[j].type std::cout << j << "\ttype: " << structure_results[j].type
<< ", region: ["; << ", region: [";
std::cout << structure_results[j].box[0] << "," std::cout << structure_results[j].box[0] << ","

View File

@ -40,7 +40,7 @@ void Classifier::Run(const std::vector<cv::Mat> &img_list,
int batch_num = end_img_no - beg_img_no; int batch_num = end_img_no - beg_img_no;
// preprocess // preprocess
std::vector<cv::Mat> norm_img_batch; std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino++) { for (int ino = beg_img_no; ino < end_img_no; ++ino) {
cv::Mat srcimg; cv::Mat srcimg;
img_list[ino].copyTo(srcimg); img_list[ino].copyTo(srcimg);
cv::Mat resize_img; cv::Mat resize_img;
@ -87,7 +87,7 @@ void Classifier::Run(const std::vector<cv::Mat> &img_list,
// postprocess // postprocess
auto postprocess_start = std::chrono::steady_clock::now(); auto postprocess_start = std::chrono::steady_clock::now();
for (int batch_idx = 0; batch_idx < predict_shape[0]; batch_idx++) { for (int batch_idx = 0; batch_idx < predict_shape[0]; ++batch_idx) {
int label = int( int label = int(
Utility::argmax(&predict_batch[batch_idx * predict_shape[1]], Utility::argmax(&predict_batch[batch_idx * predict_shape[1]],
&predict_batch[(batch_idx + 1) * predict_shape[1]])); &predict_batch[(batch_idx + 1) * predict_shape[1]]));

View File

@ -32,22 +32,22 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
std::chrono::duration<float> postprocess_diff = std::chrono::duration<float> postprocess_diff =
std::chrono::duration<float>::zero(); std::chrono::duration<float>::zero();
int img_num = img_list.size(); size_t img_num = img_list.size();
std::vector<float> width_list; std::vector<float> width_list;
for (int i = 0; i < img_num; ++i) { for (size_t i = 0; i < img_num; ++i) {
width_list.emplace_back(float(img_list[i].cols) / img_list[i].rows); width_list.emplace_back(float(img_list[i].cols) / img_list[i].rows);
} }
std::vector<int> indices = std::move(Utility::argsort(width_list)); std::vector<size_t> indices = std::move(Utility::argsort(width_list));
for (int beg_img_no = 0; beg_img_no < img_num; for (size_t beg_img_no = 0; beg_img_no < img_num;
beg_img_no += this->rec_batch_num_) { beg_img_no += this->rec_batch_num_) {
auto preprocess_start = std::chrono::steady_clock::now(); auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = std::min(img_num, beg_img_no + this->rec_batch_num_); size_t end_img_no = std::min(img_num, beg_img_no + this->rec_batch_num_);
int batch_num = end_img_no - beg_img_no; int batch_num = end_img_no - beg_img_no;
int imgH = this->rec_image_shape_[1]; int imgH = this->rec_image_shape_[1];
int imgW = this->rec_image_shape_[2]; int imgW = this->rec_image_shape_[2];
float max_wh_ratio = imgW * 1.0 / imgH; float max_wh_ratio = imgW * 1.0 / imgH;
for (int ino = beg_img_no; ino < end_img_no; ino++) { for (size_t ino = beg_img_no; ino < end_img_no; ++ino) {
int h = img_list[indices[ino]].rows; int h = img_list[indices[ino]].rows;
int w = img_list[indices[ino]].cols; int w = img_list[indices[ino]].cols;
float wh_ratio = w * 1.0 / h; float wh_ratio = w * 1.0 / h;
@ -56,7 +56,7 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
int batch_width = imgW; int batch_width = imgW;
std::vector<cv::Mat> norm_img_batch; std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino++) { for (size_t ino = beg_img_no; ino < end_img_no; ++ino) {
cv::Mat srcimg; cv::Mat srcimg;
img_list[indices[ino]].copyTo(srcimg); img_list[indices[ino]].copyTo(srcimg);
cv::Mat resize_img; cv::Mat resize_img;
@ -85,8 +85,8 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
auto output_t = this->predictor_->GetOutputHandle(output_names[0]); auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
auto predict_shape = output_t->shape(); auto predict_shape = output_t->shape();
int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1, size_t out_num = std::accumulate(predict_shape.begin(), predict_shape.end(),
std::multiplies<int>()); 1, std::multiplies<int>());
predict_batch.resize(out_num); predict_batch.resize(out_num);
// predict_batch is the result of Last FC with softmax // predict_batch is the result of Last FC with softmax
output_t->CopyToCpu(predict_batch.data()); output_t->CopyToCpu(predict_batch.data());
@ -94,7 +94,7 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
inference_diff += inference_end - inference_start; inference_diff += inference_end - inference_start;
// ctc decode // ctc decode
auto postprocess_start = std::chrono::steady_clock::now(); auto postprocess_start = std::chrono::steady_clock::now();
for (int m = 0; m < predict_shape[0]; m++) { for (int m = 0; m < predict_shape[0]; ++m) {
std::string str_res; std::string str_res;
int argmax_idx; int argmax_idx;
int last_index = 0; int last_index = 0;
@ -102,7 +102,7 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
int count = 0; int count = 0;
float max_value = 0.0f; float max_value = 0.0f;
for (int n = 0; n < predict_shape[1]; n++) { for (int n = 0; n < predict_shape[1]; ++n) {
// get idx // get idx
argmax_idx = int(Utility::argmax( argmax_idx = int(Utility::argmax(
&predict_batch[(m * predict_shape[1] + n) * predict_shape[2]], &predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],

View File

@ -65,7 +65,7 @@ PPOCR::ocr(const std::vector<cv::Mat> &img_list, bool det, bool rec,
ocr_result.resize(img_list.size()); ocr_result.resize(img_list.size());
if (cls && this->pri_->classifier_) { if (cls && this->pri_->classifier_) {
this->cls(img_list, ocr_result); this->cls(img_list, ocr_result);
for (int i = 0; i < img_list.size(); ++i) { for (size_t i = 0; i < img_list.size(); ++i) {
if (ocr_result[i].cls_label % 2 == 1 && if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->pri_->classifier_->cls_thresh) { ocr_result[i].cls_score > this->pri_->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1); cv::rotate(img_list[i], img_list[i], 1);
@ -75,11 +75,11 @@ PPOCR::ocr(const std::vector<cv::Mat> &img_list, bool det, bool rec,
if (rec) { if (rec) {
this->rec(img_list, ocr_result); this->rec(img_list, ocr_result);
} }
for (int i = 0; i < ocr_result.size(); ++i) { for (size_t i = 0; i < ocr_result.size(); ++i) {
ocr_results.emplace_back(1, std::move(ocr_result[i])); ocr_results.emplace_back(1, std::move(ocr_result[i]));
} }
} else { } else {
for (int i = 0; i < img_list.size(); ++i) { for (size_t i = 0; i < img_list.size(); ++i) {
std::vector<OCRPredictResult> ocr_result = std::vector<OCRPredictResult> ocr_result =
this->ocr(img_list[i], true, rec, cls); this->ocr(img_list[i], true, rec, cls);
ocr_results.emplace_back(std::move(ocr_result)); ocr_results.emplace_back(std::move(ocr_result));
@ -96,14 +96,14 @@ std::vector<OCRPredictResult> PPOCR::ocr(const cv::Mat &img, bool det, bool rec,
this->det(img, ocr_result); this->det(img, ocr_result);
// crop image // crop image
std::vector<cv::Mat> img_list; std::vector<cv::Mat> img_list;
for (int j = 0; j < ocr_result.size(); j++) { for (size_t j = 0; j < ocr_result.size(); ++j) {
cv::Mat crop_img = Utility::GetRotateCropImage(img, ocr_result[j].box); cv::Mat crop_img = Utility::GetRotateCropImage(img, ocr_result[j].box);
img_list.emplace_back(std::move(crop_img)); img_list.emplace_back(std::move(crop_img));
} }
// cls // cls
if (cls && this->pri_->classifier_) { if (cls && this->pri_->classifier_) {
this->cls(img_list, ocr_result); this->cls(img_list, ocr_result);
for (int i = 0; i < img_list.size(); ++i) { for (size_t i = 0; i < img_list.size(); ++i) {
if (ocr_result[i].cls_label % 2 == 1 && if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->pri_->classifier_->cls_thresh) { ocr_result[i].cls_score > this->pri_->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1); cv::rotate(img_list[i], img_list[i], 1);
@ -124,13 +124,13 @@ void PPOCR::det(const cv::Mat &img,
this->pri_->detector_->Run(img, boxes, det_times); this->pri_->detector_->Run(img, boxes, det_times);
for (int i = 0; i < boxes.size(); ++i) { for (size_t i = 0; i < boxes.size(); ++i) {
OCRPredictResult res; OCRPredictResult res;
res.box = std::move(boxes[i]); res.box = std::move(boxes[i]);
ocr_results.emplace_back(std::move(res)); ocr_results.emplace_back(std::move(res));
} }
// sort boxes from top to bottom, from left to right // sort boxes from top to bottom, from left to right
Utility::sorted_boxes(ocr_results); Utility::sort_boxes(ocr_results);
this->time_info_det[0] += det_times[0]; this->time_info_det[0] += det_times[0];
this->time_info_det[1] += det_times[1]; this->time_info_det[1] += det_times[1];
this->time_info_det[2] += det_times[2]; this->time_info_det[2] += det_times[2];
@ -143,7 +143,7 @@ void PPOCR::rec(const std::vector<cv::Mat> &img_list,
std::vector<double> rec_times; std::vector<double> rec_times;
this->pri_->recognizer_->Run(img_list, rec_texts, rec_text_scores, rec_times); this->pri_->recognizer_->Run(img_list, rec_texts, rec_text_scores, rec_times);
// output rec results // output rec results
for (int i = 0; i < rec_texts.size(); ++i) { for (size_t i = 0; i < rec_texts.size(); ++i) {
ocr_results[i].text = std::move(rec_texts[i]); ocr_results[i].text = std::move(rec_texts[i]);
ocr_results[i].score = rec_text_scores[i]; ocr_results[i].score = rec_text_scores[i];
} }
@ -159,7 +159,7 @@ void PPOCR::cls(const std::vector<cv::Mat> &img_list,
std::vector<double> cls_times; std::vector<double> cls_times;
this->pri_->classifier_->Run(img_list, cls_labels, cls_scores, cls_times); this->pri_->classifier_->Run(img_list, cls_labels, cls_scores, cls_times);
// output cls results // output cls results
for (int i = 0; i < cls_labels.size(); ++i) { for (size_t i = 0; i < cls_labels.size(); ++i) {
ocr_results[i].cls_label = cls_labels[i]; ocr_results[i].cls_label = cls_labels[i];
ocr_results[i].cls_score = cls_scores[i]; ocr_results[i].cls_score = cls_scores[i];
} }

View File

@ -64,7 +64,7 @@ PaddleStructure::structure(const cv::Mat &srcimg, bool layout, bool table,
structure_results.emplace_back(std::move(res)); structure_results.emplace_back(std::move(res));
} }
cv::Mat roi_img; cv::Mat roi_img;
for (int i = 0; i < structure_results.size(); ++i) { for (size_t i = 0; i < structure_results.size(); ++i) {
// crop image // crop image
roi_img = std::move(Utility::crop_image(img, structure_results[i].box)); roi_img = std::move(Utility::crop_image(img, structure_results[i].box));
if (structure_results[i].type == "table" && table) { if (structure_results[i].type == "table" && table) {
@ -108,13 +108,13 @@ void PaddleStructure::table(const cv::Mat &img,
std::vector<OCRPredictResult> ocr_result; std::vector<OCRPredictResult> ocr_result;
int expand_pixel = 3; int expand_pixel = 3;
for (int i = 0; i < img_list.size(); ++i) { for (size_t i = 0; i < img_list.size(); ++i) {
// det // det
this->det(img_list[i], ocr_result); this->det(img_list[i], ocr_result);
// crop image // crop image
std::vector<cv::Mat> rec_img_list; std::vector<cv::Mat> rec_img_list;
std::vector<int> ocr_box; std::vector<int> ocr_box;
for (int j = 0; j < ocr_result.size(); j++) { for (size_t j = 0; j < ocr_result.size(); ++j) {
ocr_box = std::move(Utility::xyxyxyxy2xyxy(ocr_result[j].box)); ocr_box = std::move(Utility::xyxyxyxy2xyxy(ocr_result[j].box));
ocr_box[0] = std::max(0, ocr_box[0] - expand_pixel); ocr_box[0] = std::max(0, ocr_box[0] - expand_pixel);
ocr_box[1] = std::max(0, ocr_box[1] - expand_pixel), ocr_box[1] = std::max(0, ocr_box[1] - expand_pixel),
@ -144,7 +144,7 @@ std::string PaddleStructure::rebuild_table(
std::vector<int> ocr_box; std::vector<int> ocr_box;
std::vector<int> structure_box; std::vector<int> structure_box;
for (int i = 0; i < ocr_result.size(); ++i) { for (size_t i = 0; i < ocr_result.size(); ++i) {
ocr_box = std::move(Utility::xyxyxyxy2xyxy(ocr_result[i].box)); ocr_box = std::move(Utility::xyxyxyxy2xyxy(ocr_result[i].box));
ocr_box[0] -= 1; ocr_box[0] -= 1;
ocr_box[1] -= 1; ocr_box[1] -= 1;
@ -152,8 +152,8 @@ std::string PaddleStructure::rebuild_table(
ocr_box[3] += 1; ocr_box[3] += 1;
std::vector<std::vector<float>> dis_list(structure_boxes.size(), std::vector<std::vector<float>> dis_list(structure_boxes.size(),
std::vector<float>(3, 100000.0)); std::vector<float>(3, 100000.0));
for (int j = 0; j < structure_boxes.size(); j++) { for (size_t j = 0; j < structure_boxes.size(); ++j) {
if (structure_boxes[i].size() == 8) { if (structure_boxes[j].size() == 8) {
structure_box = std::move(Utility::xyxyxyxy2xyxy(structure_boxes[j])); structure_box = std::move(Utility::xyxyxyxy2xyxy(structure_boxes[j]));
} else { } else {
structure_box = structure_boxes[j]; structure_box = structure_boxes[j];
@ -171,7 +171,7 @@ std::string PaddleStructure::rebuild_table(
// get pred html // get pred html
std::string html_str = ""; std::string html_str = "";
int td_tag_idx = 0; int td_tag_idx = 0;
for (int i = 0; i < structure_html_tags.size(); ++i) { for (size_t i = 0; i < structure_html_tags.size(); ++i) {
if (structure_html_tags[i].find("</td>") != std::string::npos) { if (structure_html_tags[i].find("</td>") != std::string::npos) {
if (structure_html_tags[i].find("<td></td>") != std::string::npos) { if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
html_str += "<td>"; html_str += "<td>";
@ -183,7 +183,7 @@ std::string PaddleStructure::rebuild_table(
b_with = true; b_with = true;
html_str += "<b>"; html_str += "<b>";
} }
for (int j = 0; j < matched[td_tag_idx].size(); j++) { for (size_t j = 0; j < matched[td_tag_idx].size(); ++j) {
std::string content = matched[td_tag_idx][j]; std::string content = matched[td_tag_idx][j];
if (matched[td_tag_idx].size() > 1) { if (matched[td_tag_idx].size() > 1) {
// remove blank, <b> and </b> // remove blank, <b> and </b>

View File

@ -57,8 +57,8 @@ DBPostProcessor::UnClip(const std::vector<std::vector<float>> &box,
std::vector<cv::Point2f> points; std::vector<cv::Point2f> points;
for (int j = 0; j < soln.size(); j++) { for (size_t j = 0; j < soln.size(); ++j) {
for (int i = 0; i < soln[soln.size() - 1].size(); ++i) { for (size_t i = 0; i < soln[soln.size() - 1].size(); ++i) {
points.emplace_back(soln[j][i].X, soln[j][i].Y); points.emplace_back(soln[j][i].X, soln[j][i].Y);
} }
} }
@ -173,7 +173,7 @@ float DBPostProcessor::PolygonScoreAcc(const std::vector<cv::Point> &contour,
int height = pred.rows; int height = pred.rows;
std::vector<float> box_x; std::vector<float> box_x;
std::vector<float> box_y; std::vector<float> box_y;
for (int i = 0; i < contour.size(); ++i) { for (size_t i = 0; i < contour.size(); ++i) {
box_x.emplace_back(contour[i].x); box_x.emplace_back(contour[i].x);
box_y.emplace_back(contour[i].y); box_y.emplace_back(contour[i].y);
} }
@ -196,7 +196,7 @@ float DBPostProcessor::PolygonScoreAcc(const std::vector<cv::Point> &contour,
cv::Point *rook_point = new cv::Point[contour.size()]; cv::Point *rook_point = new cv::Point[contour.size()];
for (int i = 0; i < contour.size(); ++i) { for (size_t i = 0; i < contour.size(); ++i) {
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin); rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
} }
const cv::Point *ppt[1] = {rook_point}; const cv::Point *ppt[1] = {rook_point};
@ -273,7 +273,7 @@ std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
std::vector<std::vector<std::vector<int>>> boxes; std::vector<std::vector<std::vector<int>>> boxes;
for (int _i = 0; _i < num_contours; _i++) { for (int _i = 0; _i < num_contours; ++_i) {
if (contours[_i].size() <= 2) { if (contours[_i].size() <= 2) {
continue; continue;
} }
@ -315,7 +315,7 @@ std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
int dest_height = pred.rows; int dest_height = pred.rows;
std::vector<std::vector<int>> intcliparray; std::vector<std::vector<int>> intcliparray;
for (int num_pt = 0; num_pt < 4; num_pt++) { for (int num_pt = 0; num_pt < 4; ++num_pt) {
std::vector<int> a{int(clampf(roundf(cliparray[num_pt][0] / float(width) * std::vector<int> a{int(clampf(roundf(cliparray[num_pt][0] / float(width) *
float(dest_width)), float(dest_width)),
0, float(dest_width))), 0, float(dest_width))),
@ -337,9 +337,9 @@ void DBPostProcessor::FilterTagDetRes(
int oriimg_w = srcimg.cols; int oriimg_w = srcimg.cols;
std::vector<std::vector<std::vector<int>>> root_points; std::vector<std::vector<std::vector<int>>> root_points;
for (int n = 0; n < boxes.size(); n++) { for (size_t n = 0; n < boxes.size(); ++n) {
boxes[n] = OrderPointsClockwise(boxes[n]); boxes[n] = OrderPointsClockwise(boxes[n]);
for (int m = 0; m < boxes[0].size(); m++) { for (size_t m = 0; m < boxes[0].size(); ++m) {
boxes[n][m][0] /= ratio_w; boxes[n][m][0] /= ratio_w;
boxes[n][m][1] /= ratio_h; boxes[n][m][1] /= ratio_h;
@ -348,7 +348,7 @@ void DBPostProcessor::FilterTagDetRes(
} }
} }
for (int n = 0; n < boxes.size(); n++) { for (size_t n = 0; n < boxes.size(); ++n) {
int rect_width, rect_height; int rect_width, rect_height;
rect_width = int(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) + rect_width = int(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) +
pow(boxes[n][0][1] - boxes[n][1][1], 2))); pow(boxes[n][0][1] - boxes[n][1][1], 2)));
@ -389,7 +389,7 @@ void TablePostProcessor::Run(
std::vector<std::vector<std::vector<int>>> &rec_boxes_batch, std::vector<std::vector<std::vector<int>>> &rec_boxes_batch,
const std::vector<int> &width_list, const std::vector<int> &width_list,
const std::vector<int> &height_list) noexcept { const std::vector<int> &height_list) noexcept {
for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) { for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; ++batch_idx) {
// image tags and boxes // image tags and boxes
std::vector<std::string> rec_html_tags; std::vector<std::string> rec_html_tags;
std::vector<std::vector<int>> rec_boxes; std::vector<std::vector<int>> rec_boxes;
@ -400,7 +400,7 @@ void TablePostProcessor::Run(
int char_idx = 0; int char_idx = 0;
// step // step
for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) { for (int step_idx = 0; step_idx < structure_probs_shape[1]; ++step_idx) {
std::string html_tag; std::string html_tag;
std::vector<int> rec_box; std::vector<int> rec_box;
// html tag // html tag
@ -426,7 +426,7 @@ void TablePostProcessor::Run(
// box // box
if (html_tag == "<td>" || html_tag == "<td" || html_tag == "<td></td>") { if (html_tag == "<td>" || html_tag == "<td" || html_tag == "<td></td>") {
for (int point_idx = 0; point_idx < loc_preds_shape[2]; point_idx++) { for (int point_idx = 0; point_idx < loc_preds_shape[2]; ++point_idx) {
step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
loc_preds_shape[2] + loc_preds_shape[2] +
point_idx; point_idx;
@ -474,16 +474,18 @@ void PicodetPostProcessor::Run(std::vector<StructurePredictResult> &results,
std::vector<std::vector<StructurePredictResult>> bbox_results; std::vector<std::vector<StructurePredictResult>> bbox_results;
bbox_results.resize(this->num_class_); bbox_results.resize(this->num_class_);
for (int i = 0; i < this->fpn_stride_.size(); ++i) { for (size_t i = 0; i < this->fpn_stride_.size(); ++i) {
int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]); const int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]);
int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]); const int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]);
for (int idx = 0; idx < feature_h * feature_w; idx++) { const size_t hxw = feature_h * feature_w;
for (size_t idx = 0; idx < hxw; ++idx) {
// score and label // score and label
float score = 0; float score = 0;
int cur_label = 0; int cur_label = 0;
for (int label = 0; label < this->num_class_; label++) { for (size_t label = 0; label < this->label_list_.size(); ++label) {
if (outs[i][idx * this->num_class_ + label] > score) { float osc = outs[i][idx * this->label_list_.size() + label];
score = outs[i][idx * this->num_class_ + label]; if (osc > score) {
score = osc;
cur_label = label; cur_label = label;
} }
} }
@ -501,11 +503,11 @@ void PicodetPostProcessor::Run(std::vector<StructurePredictResult> &results,
} }
} }
#if 0 #if 0
for (int i = 0; i < bbox_results.size(); ++i) { for (size_t i = 0; i < bbox_results.size(); ++i) {
bool flag = bbox_results[i].size() <= 0; bool flag = bbox_results[i].size() <= 0;
} }
#endif #endif
for (int i = 0; i < bbox_results.size(); ++i) { for (size_t i = 0; i < bbox_results.size(); ++i) {
// bool flag = bbox_results[i].size() <= 0; // bool flag = bbox_results[i].size() <= 0;
if (bbox_results[i].size() <= 0) { if (bbox_results[i].size() <= 0) {
continue; continue;
@ -534,7 +536,7 @@ StructurePredictResult PicodetPostProcessor::disPred2Bbox(
std::vector<float> bbox_pred_i(itemp, itemp + reg_max); std::vector<float> bbox_pred_i(itemp, itemp + reg_max);
std::vector<float> dis_after_sm( std::vector<float> dis_after_sm(
std::move(Utility::activation_function_softmax(bbox_pred_i))); std::move(Utility::activation_function_softmax(bbox_pred_i)));
for (int j = 0; j < reg_max; j++) { for (int j = 0; j < reg_max; ++j) {
dis += j * dis_after_sm[j]; dis += j * dis_after_sm[j];
} }
dis *= stride; dis *= stride;
@ -562,11 +564,11 @@ void PicodetPostProcessor::nms(std::vector<StructurePredictResult> &input_boxes,
}); });
std::vector<int> picked(input_boxes.size(), 1); std::vector<int> picked(input_boxes.size(), 1);
for (int i = 0; i < input_boxes.size(); ++i) { for (size_t i = 0; i < input_boxes.size(); ++i) {
if (picked[i] == 0) { if (picked[i] == 0) {
continue; continue;
} }
for (int j = i + 1; j < input_boxes.size(); ++j) { for (size_t j = i + 1; j < input_boxes.size(); ++j) {
if (picked[j] == 0) { if (picked[j] == 0) {
continue; continue;
} }
@ -577,7 +579,7 @@ void PicodetPostProcessor::nms(std::vector<StructurePredictResult> &input_boxes,
} }
} }
std::vector<StructurePredictResult> input_boxes_nms; std::vector<StructurePredictResult> input_boxes_nms;
for (int i = 0; i < input_boxes.size(); ++i) { for (size_t i = 0; i < input_boxes.size(); ++i) {
if (picked[i] == 1) { if (picked[i] == 1) {
input_boxes_nms.emplace_back(input_boxes[i]); input_boxes_nms.emplace_back(input_boxes[i]);
} }

View File

@ -26,7 +26,7 @@ void Permute::Run(const cv::Mat &im, float *data) noexcept {
} }
void PermuteBatch::Run(const std::vector<cv::Mat> &imgs, float *data) noexcept { void PermuteBatch::Run(const std::vector<cv::Mat> &imgs, float *data) noexcept {
for (int j = 0; j < imgs.size(); j++) { for (size_t j = 0; j < imgs.size(); ++j) {
int rh = imgs[j].rows; int rh = imgs[j].rows;
int rw = imgs[j].cols; int rw = imgs[j].cols;
int rc = imgs[j].channels(); int rc = imgs[j].channels();
@ -47,7 +47,7 @@ void Normalize::Run(cv::Mat &im, const std::vector<float> &mean,
im.convertTo(im, CV_32FC3, e); im.convertTo(im, CV_32FC3, e);
std::vector<cv::Mat> bgr_channels(3); std::vector<cv::Mat> bgr_channels(3);
cv::split(im, bgr_channels); cv::split(im, bgr_channels);
for (auto i = 0; i < bgr_channels.size(); ++i) { for (size_t i = 0; i < bgr_channels.size(); ++i) {
bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i], bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i],
(0.0 - mean[i]) * scale[i]); (0.0 - mean[i]) * scale[i]);
} }

View File

@ -59,7 +59,7 @@ void StructureLayoutRecognizer::Run(const cv::Mat &img,
std::vector<std::vector<float>> out_tensor_list; std::vector<std::vector<float>> out_tensor_list;
std::vector<std::vector<int>> output_shape_list; std::vector<std::vector<int>> output_shape_list;
auto output_names = this->predictor_->GetOutputNames(); auto output_names = this->predictor_->GetOutputNames();
for (int j = 0; j < output_names.size(); j++) { for (size_t j = 0; j < output_names.size(); ++j) {
auto output_tensor = this->predictor_->GetOutputHandle(output_names[j]); auto output_tensor = this->predictor_->GetOutputHandle(output_names[j]);
std::vector<int> output_shape = output_tensor->shape(); std::vector<int> output_shape = output_tensor->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
@ -79,7 +79,7 @@ void StructureLayoutRecognizer::Run(const cv::Mat &img,
std::vector<int> bbox_num; std::vector<int> bbox_num;
int reg_max = 0; int reg_max = 0;
for (int i = 0; i < out_tensor_list.size(); ++i) { for (size_t i = 0; i < out_tensor_list.size(); ++i) {
if (i == this->post_processor_.fpn_stride_size()) { if (i == this->post_processor_.fpn_stride_size()) {
reg_max = output_shape_list[i][2] / 4; reg_max = output_shape_list[i][2] / 4;
break; break;

View File

@ -33,17 +33,17 @@ void StructureTableRecognizer::Run(
std::chrono::duration<float> postprocess_diff = std::chrono::duration<float> postprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
int img_num = img_list.size(); size_t img_num = img_list.size();
for (int beg_img_no = 0; beg_img_no < img_num; for (size_t beg_img_no = 0; beg_img_no < img_num;
beg_img_no += this->table_batch_num_) { beg_img_no += this->table_batch_num_) {
// preprocess // preprocess
auto preprocess_start = std::chrono::steady_clock::now(); auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = std::min(img_num, beg_img_no + this->table_batch_num_); size_t end_img_no = std::min(img_num, beg_img_no + this->table_batch_num_);
int batch_num = end_img_no - beg_img_no; int batch_num = end_img_no - beg_img_no;
std::vector<cv::Mat> norm_img_batch; std::vector<cv::Mat> norm_img_batch;
std::vector<int> width_list; std::vector<int> width_list;
std::vector<int> height_list; std::vector<int> height_list;
for (int ino = beg_img_no; ino < end_img_no; ino++) { for (size_t ino = beg_img_no; ino < end_img_no; ++ino) {
cv::Mat srcimg; cv::Mat srcimg;
img_list[ino].copyTo(srcimg); img_list[ino].copyTo(srcimg);
cv::Mat resize_img; cv::Mat resize_img;
@ -98,7 +98,7 @@ void StructureTableRecognizer::Run(
predict_shape0, predict_shape1, predict_shape0, predict_shape1,
structure_html_tag_batch, structure_boxes_batch, structure_html_tag_batch, structure_boxes_batch,
width_list, height_list); width_list, height_list);
for (int m = 0; m < predict_shape0[0]; m++) { for (int m = 0; m < predict_shape0[0]; ++m) {
structure_html_tag_batch[m].emplace(structure_html_tag_batch[m].begin(), structure_html_tag_batch[m].emplace(structure_html_tag_batch[m].begin(),
"<table>"); "<table>");

View File

@ -52,9 +52,9 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg,
const std::string &save_path) noexcept { const std::string &save_path) noexcept {
cv::Mat img_vis; cv::Mat img_vis;
srcimg.copyTo(img_vis); srcimg.copyTo(img_vis);
for (int n = 0; n < ocr_result.size(); ++n) { for (size_t n = 0; n < ocr_result.size(); ++n) {
cv::Point rook_points[4]; cv::Point rook_points[4];
for (int m = 0; m < ocr_result[n].box.size(); ++m) { for (size_t m = 0; m < ocr_result[n].box.size(); ++m) {
rook_points[m] = rook_points[m] =
cv::Point(int(ocr_result[n].box[m][0]), int(ocr_result[n].box[m][1])); cv::Point(int(ocr_result[n].box[m][0]), int(ocr_result[n].box[m][1]));
} }
@ -75,10 +75,10 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg,
cv::Mat img_vis; cv::Mat img_vis;
srcimg.copyTo(img_vis); srcimg.copyTo(img_vis);
img_vis = crop_image(img_vis, structure_result.box); img_vis = crop_image(img_vis, structure_result.box);
for (int n = 0; n < structure_result.cell_box.size(); ++n) { for (size_t n = 0; n < structure_result.cell_box.size(); ++n) {
if (structure_result.cell_box[n].size() == 8) { if (structure_result.cell_box[n].size() == 8) {
cv::Point rook_points[4]; cv::Point rook_points[4];
for (int m = 0; m < structure_result.cell_box[n].size(); m += 2) { for (size_t m = 0; m < structure_result.cell_box[n].size(); m += 2) {
rook_points[m / 2] = rook_points[m / 2] =
cv::Point(int(structure_result.cell_box[n][m]), cv::Point(int(structure_result.cell_box[n][m]),
int(structure_result.cell_box[n][m + 1])); int(structure_result.cell_box[n][m + 1]));
@ -151,7 +151,7 @@ Utility::GetRotateCropImage(const cv::Mat &srcimage,
cv::Mat img_crop; cv::Mat img_crop;
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop); image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
for (int i = 0; i < points.size(); ++i) { for (size_t i = 0; i < points.size(); ++i) {
points[i][0] -= left; points[i][0] -= left;
points[i][1] -= top; points[i][1] -= top;
} }
@ -190,14 +190,15 @@ Utility::GetRotateCropImage(const cv::Mat &srcimage,
} }
} }
// Returns the index permutation that orders `array` ascending: element k of
// the result is the position in `array` of its k-th smallest value.
// An empty input yields an empty result. std::sort is not stable, so the
// relative order of indices whose values compare equal is unspecified.
std::vector<size_t> Utility::argsort(const std::vector<float> &array) noexcept {
  const size_t count = array.size();
  std::vector<size_t> array_index(count, 0);
  for (size_t i = 0; i < count; ++i) {
    array_index[i] = i;
  }
  // Sort the indices rather than the values, comparing through `array`.
  std::sort(array_index.begin(), array_index.end(),
            [&array](size_t pos1, size_t pos2) {
              return array[pos1] < array[pos2];
            });
  return array_index;
}
@ -237,35 +238,35 @@ std::string Utility::basename(const std::string &filename) noexcept {
return filename.substr(index + 1, len - index); return filename.substr(index + 1, len - index);
} }
// Reports whether `path` can be stat'ed, i.e. whether stat() (or _stat() on
// Windows) succeeds for it. Every stat failure is reported as `false`; the
// reason for the failure is not distinguished.
// `path` is a NUL-terminated C string. A null pointer is treated as a
// non-existent path instead of being handed to stat().
bool Utility::PathExists(const char *path) noexcept {
  if (path == nullptr) {
    return false;
  }
#ifdef _WIN32
  struct _stat buffer;
  return (_stat(path, &buffer) == 0);
#else
  struct stat buffer;
  return (stat(path, &buffer) == 0);
#endif // !_WIN32
}
// Creates the directory named by `path` with a single mkdir call, so missing
// parent directories are not created. The result of mkdir is not checked:
// an already-existing directory and any other failure are both silent.
// On non-Windows builds the requested mode is 0777 (subject to the process
// umask). A null `path` is ignored instead of being handed to mkdir.
void Utility::CreateDir(const char *path) noexcept {
  if (path == nullptr) {
    return;
  }
#ifdef _MSC_VER
  _mkdir(path);
#elif defined __MINGW32__
  mkdir(path);
#else
  mkdir(path, 0777);
#endif // !_WIN32
}
void Utility::print_result( void Utility::print_result(
const std::vector<OCRPredictResult> &ocr_result) noexcept { const std::vector<OCRPredictResult> &ocr_result) noexcept {
for (int i = 0; i < ocr_result.size(); ++i) { for (size_t i = 0; i < ocr_result.size(); ++i) {
std::cout << i << "\t"; std::cout << i << "\t";
// det // det
const std::vector<std::vector<int>> &boxes = ocr_result[i].box; const std::vector<std::vector<int>> &boxes = ocr_result[i].box;
if (boxes.size() > 0) { if (boxes.size() > 0) {
std::cout << "det boxes: ["; std::cout << "det boxes: [";
for (int n = 0; n < boxes.size(); n++) { for (size_t n = 0; n < boxes.size(); ++n) {
std::cout << '[' << boxes[n][0] << ',' << boxes[n][1] << "]"; std::cout << '[' << boxes[n][0] << ',' << boxes[n][1] << "]";
if (n != boxes.size() - 1) { if (n != boxes.size() - 1) {
std::cout << ','; std::cout << ',';
@ -288,7 +289,7 @@ void Utility::print_result(
} }
} }
cv::Mat Utility::crop_image(cv::Mat &img, cv::Mat Utility::crop_image(const cv::Mat &img,
const std::vector<int> &box) noexcept { const std::vector<int> &box) noexcept {
cv::Mat crop_im = cv::Mat::zeros(box[3] - box[1], box[2] - box[0], 16); cv::Mat crop_im = cv::Mat::zeros(box[3] - box[1], box[2] - box[0], 16);
int crop_x1 = std::max(0, box[0]); int crop_x1 = std::max(0, box[0]);
@ -305,18 +306,18 @@ cv::Mat Utility::crop_image(cv::Mat &img,
return crop_im; return crop_im;
} }
// Float-box overload of crop_image: converts the first four entries of `box`
// to int (truncating toward zero) and forwards to the integer overload.
// Precondition: `box` holds at least four elements; no bounds check is done.
cv::Mat Utility::crop_image(const cv::Mat &img,
                            const std::vector<float> &box) noexcept {
  const std::vector<int> box_int = {
      static_cast<int>(box[0]), static_cast<int>(box[1]),
      static_cast<int>(box[2]), static_cast<int>(box[3])};
  return crop_image(img, box_int);
}
void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) noexcept { void Utility::sort_boxes(std::vector<OCRPredictResult> &ocr_result) noexcept {
std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box); std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
if (ocr_result.size() > 0) { if (ocr_result.size() > 1) {
for (int i = 0; i < ocr_result.size() - 1; ++i) { for (size_t i = 0; i < ocr_result.size() - 1; ++i) {
for (int j = i; j >= 0; j--) { for (size_t j = i; j != size_t(-1); --j) {
if (abs(ocr_result[j + 1].box[0][1] - ocr_result[j].box[0][1]) < 10 && if (abs(ocr_result[j + 1].box[0][1] - ocr_result[j].box[0][1]) < 10 &&
(ocr_result[j + 1].box[0][0] < ocr_result[j].box[0][0])) { (ocr_result[j + 1].box[0][0] < ocr_result[j].box[0][0])) {
std::swap(ocr_result[i], ocr_result[i + 1]); std::swap(ocr_result[i], ocr_result[i + 1]);
@ -367,25 +368,26 @@ float Utility::fast_exp(float x) noexcept {
} }
std::vector<float> std::vector<float>
Utility::activation_function_softmax(std::vector<float> &src) noexcept { Utility::activation_function_softmax(const std::vector<float> &src) noexcept {
int length = src.size(); size_t length = src.size();
std::vector<float> dst; std::vector<float> dst;
dst.resize(length); dst.resize(length);
const float alpha = float(*std::max_element(&src[0], &src[0 + length])); const float alpha = float(*std::max_element(&src[0], &src[length]));
float denominator{0}; float denominator{0};
for (int i = 0; i < length; ++i) { for (size_t i = 0; i < length; ++i) {
dst[i] = fast_exp(src[i] - alpha); dst[i] = fast_exp(src[i] - alpha);
denominator += dst[i]; denominator += dst[i];
} }
for (int i = 0; i < length; ++i) { for (size_t i = 0; i < length; ++i) {
dst[i] /= denominator; dst[i] /= denominator;
} }
return dst; return dst;
} }
float Utility::iou(std::vector<int> &box1, std::vector<int> &box2) noexcept { float Utility::iou(const std::vector<int> &box1,
const std::vector<int> &box2) noexcept {
int area1 = std::max(0, box1[2] - box1[0]) * std::max(0, box1[3] - box1[1]); int area1 = std::max(0, box1[2] - box1[0]) * std::max(0, box1[3] - box1[1]);
int area2 = std::max(0, box2[2] - box2[0]) * std::max(0, box2[3] - box2[1]); int area2 = std::max(0, box2[2] - box2[0]) * std::max(0, box2[3] - box2[1]);
@ -407,8 +409,8 @@ float Utility::iou(std::vector<int> &box1, std::vector<int> &box2) noexcept {
} }
} }
float Utility::iou(std::vector<float> &box1, float Utility::iou(const std::vector<float> &box1,
std::vector<float> &box2) noexcept { const std::vector<float> &box2) noexcept {
float area1 = std::max((float)0.0, box1[2] - box1[0]) * float area1 = std::max((float)0.0, box1[2] - box1[0]) *
std::max((float)0.0, box1[3] - box1[1]); std::max((float)0.0, box1[3] - box1[1]);
float area2 = std::max((float)0.0, box2[2] - box2[0]) * float area2 = std::max((float)0.0, box2[2] - box2[0]) *