mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-12-12 07:17:32 +00:00
CPP: code review and some minimal fixes (compiler warnings...) (#14635)
* CPP: minimal fixes (compiler warnings...) * Code review
This commit is contained in:
parent
2c0c4beb06
commit
02f106d668
@ -61,23 +61,29 @@ public:
|
||||
GetRotateCropImage(const cv::Mat &srcimage,
|
||||
const std::vector<std::vector<int>> &box) noexcept;
|
||||
|
||||
static std::vector<int> argsort(const std::vector<float> &array) noexcept;
|
||||
static std::vector<size_t> argsort(const std::vector<float> &array) noexcept;
|
||||
|
||||
static std::string basename(const std::string &filename) noexcept;
|
||||
|
||||
static bool PathExists(const std::string &path) noexcept;
|
||||
static bool PathExists(const char *path) noexcept;
|
||||
static inline bool PathExists(const std::string &path) noexcept {
|
||||
return PathExists(path.c_str());
|
||||
}
|
||||
|
||||
static void CreateDir(const std::string &path) noexcept;
|
||||
static void CreateDir(const char *path) noexcept;
|
||||
static inline void CreateDir(const std::string &path) noexcept {
|
||||
CreateDir(path.c_str());
|
||||
}
|
||||
|
||||
static void
|
||||
print_result(const std::vector<OCRPredictResult> &ocr_result) noexcept;
|
||||
|
||||
static cv::Mat crop_image(cv::Mat &img,
|
||||
static cv::Mat crop_image(const cv::Mat &img,
|
||||
const std::vector<int> &area) noexcept;
|
||||
static cv::Mat crop_image(cv::Mat &img,
|
||||
static cv::Mat crop_image(const cv::Mat &img,
|
||||
const std::vector<float> &area) noexcept;
|
||||
|
||||
static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result) noexcept;
|
||||
static void sort_boxes(std::vector<OCRPredictResult> &ocr_result) noexcept;
|
||||
|
||||
static std::vector<int>
|
||||
xyxyxyxy2xyxy(const std::vector<std::vector<int>> &box) noexcept;
|
||||
@ -85,9 +91,11 @@ public:
|
||||
|
||||
static float fast_exp(float x) noexcept;
|
||||
static std::vector<float>
|
||||
activation_function_softmax(std::vector<float> &src) noexcept;
|
||||
static float iou(std::vector<int> &box1, std::vector<int> &box2) noexcept;
|
||||
static float iou(std::vector<float> &box1, std::vector<float> &box2) noexcept;
|
||||
activation_function_softmax(const std::vector<float> &src) noexcept;
|
||||
static float iou(const std::vector<int> &box1,
|
||||
const std::vector<int> &box2) noexcept;
|
||||
static float iou(const std::vector<float> &box1,
|
||||
const std::vector<float> &box2) noexcept;
|
||||
|
||||
private:
|
||||
static bool comparison_box(const OCRPredictResult &result1,
|
||||
|
||||
@ -2812,7 +2812,7 @@ bool Clipper::FixupIntersectionOrder() noexcept {
|
||||
if (!EdgesAdjacent(*m_IntersectList[i])) {
|
||||
size_t j = i + 1;
|
||||
while (j < cnt && !EdgesAdjacent(*m_IntersectList[j]))
|
||||
j++;
|
||||
++j;
|
||||
if (j == cnt)
|
||||
return false;
|
||||
std::swap(m_IntersectList[i], m_IntersectList[j]);
|
||||
@ -3078,7 +3078,7 @@ void Clipper::BuildResult2(PolyTree &polytree) noexcept {
|
||||
pn->Index = 0;
|
||||
pn->Contour.reserve(cnt);
|
||||
OutPt *op = outRec->Pts->Prev;
|
||||
for (int j = 0; j < cnt; j++) {
|
||||
for (int j = 0; j < cnt; ++j) {
|
||||
pn->Contour.emplace_back(op->Pt);
|
||||
op = op->Prev;
|
||||
}
|
||||
@ -3643,7 +3643,7 @@ void ClipperOffset::AddPath(const Path &path, JoinType joinType,
|
||||
int j = 0, k = 0;
|
||||
for (int i = 1; i <= highI; ++i)
|
||||
if (newNode->Contour[j] != path[i]) {
|
||||
j++;
|
||||
++j;
|
||||
newNode->Contour.emplace_back(path[i]);
|
||||
if (path[i].Y > newNode->Contour[k].Y ||
|
||||
(path[i].Y == newNode->Contour[k].Y &&
|
||||
@ -3826,7 +3826,7 @@ void ClipperOffset::DoOffset(double delta) noexcept {
|
||||
if (len == 1) {
|
||||
if (node.m_jointype == jtRound) {
|
||||
double X = 1.0, Y = 0.0;
|
||||
for (cInt j = 1; j <= steps; j++) {
|
||||
for (cInt j = 1; j <= steps; ++j) {
|
||||
m_destPoly.emplace_back(Round(m_srcPoly[0].X + X * delta),
|
||||
Round(m_srcPoly[0].Y + Y * delta));
|
||||
double X2 = X;
|
||||
|
||||
@ -137,7 +137,7 @@ void structure(std::vector<cv::String> &cv_all_img_names) {
|
||||
std::vector<StructurePredictResult> structure_results = engine.structure(
|
||||
img, FLAGS_layout, FLAGS_table, FLAGS_det && FLAGS_rec);
|
||||
|
||||
for (int j = 0; j < structure_results.size(); j++) {
|
||||
for (size_t j = 0; j < structure_results.size(); ++j) {
|
||||
std::cout << j << "\ttype: " << structure_results[j].type
|
||||
<< ", region: [";
|
||||
std::cout << structure_results[j].box[0] << ","
|
||||
|
||||
@ -40,7 +40,7 @@ void Classifier::Run(const std::vector<cv::Mat> &img_list,
|
||||
int batch_num = end_img_no - beg_img_no;
|
||||
// preprocess
|
||||
std::vector<cv::Mat> norm_img_batch;
|
||||
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
||||
for (int ino = beg_img_no; ino < end_img_no; ++ino) {
|
||||
cv::Mat srcimg;
|
||||
img_list[ino].copyTo(srcimg);
|
||||
cv::Mat resize_img;
|
||||
@ -87,7 +87,7 @@ void Classifier::Run(const std::vector<cv::Mat> &img_list,
|
||||
|
||||
// postprocess
|
||||
auto postprocess_start = std::chrono::steady_clock::now();
|
||||
for (int batch_idx = 0; batch_idx < predict_shape[0]; batch_idx++) {
|
||||
for (int batch_idx = 0; batch_idx < predict_shape[0]; ++batch_idx) {
|
||||
int label = int(
|
||||
Utility::argmax(&predict_batch[batch_idx * predict_shape[1]],
|
||||
&predict_batch[(batch_idx + 1) * predict_shape[1]]));
|
||||
|
||||
@ -32,22 +32,22 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
|
||||
std::chrono::duration<float> postprocess_diff =
|
||||
std::chrono::duration<float>::zero();
|
||||
|
||||
int img_num = img_list.size();
|
||||
size_t img_num = img_list.size();
|
||||
std::vector<float> width_list;
|
||||
for (int i = 0; i < img_num; ++i) {
|
||||
for (size_t i = 0; i < img_num; ++i) {
|
||||
width_list.emplace_back(float(img_list[i].cols) / img_list[i].rows);
|
||||
}
|
||||
std::vector<int> indices = std::move(Utility::argsort(width_list));
|
||||
std::vector<size_t> indices = std::move(Utility::argsort(width_list));
|
||||
|
||||
for (int beg_img_no = 0; beg_img_no < img_num;
|
||||
for (size_t beg_img_no = 0; beg_img_no < img_num;
|
||||
beg_img_no += this->rec_batch_num_) {
|
||||
auto preprocess_start = std::chrono::steady_clock::now();
|
||||
int end_img_no = std::min(img_num, beg_img_no + this->rec_batch_num_);
|
||||
size_t end_img_no = std::min(img_num, beg_img_no + this->rec_batch_num_);
|
||||
int batch_num = end_img_no - beg_img_no;
|
||||
int imgH = this->rec_image_shape_[1];
|
||||
int imgW = this->rec_image_shape_[2];
|
||||
float max_wh_ratio = imgW * 1.0 / imgH;
|
||||
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
||||
for (size_t ino = beg_img_no; ino < end_img_no; ++ino) {
|
||||
int h = img_list[indices[ino]].rows;
|
||||
int w = img_list[indices[ino]].cols;
|
||||
float wh_ratio = w * 1.0 / h;
|
||||
@ -56,7 +56,7 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
|
||||
|
||||
int batch_width = imgW;
|
||||
std::vector<cv::Mat> norm_img_batch;
|
||||
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
||||
for (size_t ino = beg_img_no; ino < end_img_no; ++ino) {
|
||||
cv::Mat srcimg;
|
||||
img_list[indices[ino]].copyTo(srcimg);
|
||||
cv::Mat resize_img;
|
||||
@ -85,8 +85,8 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
|
||||
auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
|
||||
auto predict_shape = output_t->shape();
|
||||
|
||||
int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1,
|
||||
std::multiplies<int>());
|
||||
size_t out_num = std::accumulate(predict_shape.begin(), predict_shape.end(),
|
||||
1, std::multiplies<int>());
|
||||
predict_batch.resize(out_num);
|
||||
// predict_batch is the result of Last FC with softmax
|
||||
output_t->CopyToCpu(predict_batch.data());
|
||||
@ -94,7 +94,7 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
|
||||
inference_diff += inference_end - inference_start;
|
||||
// ctc decode
|
||||
auto postprocess_start = std::chrono::steady_clock::now();
|
||||
for (int m = 0; m < predict_shape[0]; m++) {
|
||||
for (int m = 0; m < predict_shape[0]; ++m) {
|
||||
std::string str_res;
|
||||
int argmax_idx;
|
||||
int last_index = 0;
|
||||
@ -102,7 +102,7 @@ void CRNNRecognizer::Run(const std::vector<cv::Mat> &img_list,
|
||||
int count = 0;
|
||||
float max_value = 0.0f;
|
||||
|
||||
for (int n = 0; n < predict_shape[1]; n++) {
|
||||
for (int n = 0; n < predict_shape[1]; ++n) {
|
||||
// get idx
|
||||
argmax_idx = int(Utility::argmax(
|
||||
&predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
|
||||
|
||||
@ -65,7 +65,7 @@ PPOCR::ocr(const std::vector<cv::Mat> &img_list, bool det, bool rec,
|
||||
ocr_result.resize(img_list.size());
|
||||
if (cls && this->pri_->classifier_) {
|
||||
this->cls(img_list, ocr_result);
|
||||
for (int i = 0; i < img_list.size(); ++i) {
|
||||
for (size_t i = 0; i < img_list.size(); ++i) {
|
||||
if (ocr_result[i].cls_label % 2 == 1 &&
|
||||
ocr_result[i].cls_score > this->pri_->classifier_->cls_thresh) {
|
||||
cv::rotate(img_list[i], img_list[i], 1);
|
||||
@ -75,11 +75,11 @@ PPOCR::ocr(const std::vector<cv::Mat> &img_list, bool det, bool rec,
|
||||
if (rec) {
|
||||
this->rec(img_list, ocr_result);
|
||||
}
|
||||
for (int i = 0; i < ocr_result.size(); ++i) {
|
||||
for (size_t i = 0; i < ocr_result.size(); ++i) {
|
||||
ocr_results.emplace_back(1, std::move(ocr_result[i]));
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < img_list.size(); ++i) {
|
||||
for (size_t i = 0; i < img_list.size(); ++i) {
|
||||
std::vector<OCRPredictResult> ocr_result =
|
||||
this->ocr(img_list[i], true, rec, cls);
|
||||
ocr_results.emplace_back(std::move(ocr_result));
|
||||
@ -96,14 +96,14 @@ std::vector<OCRPredictResult> PPOCR::ocr(const cv::Mat &img, bool det, bool rec,
|
||||
this->det(img, ocr_result);
|
||||
// crop image
|
||||
std::vector<cv::Mat> img_list;
|
||||
for (int j = 0; j < ocr_result.size(); j++) {
|
||||
for (size_t j = 0; j < ocr_result.size(); ++j) {
|
||||
cv::Mat crop_img = Utility::GetRotateCropImage(img, ocr_result[j].box);
|
||||
img_list.emplace_back(std::move(crop_img));
|
||||
}
|
||||
// cls
|
||||
if (cls && this->pri_->classifier_) {
|
||||
this->cls(img_list, ocr_result);
|
||||
for (int i = 0; i < img_list.size(); ++i) {
|
||||
for (size_t i = 0; i < img_list.size(); ++i) {
|
||||
if (ocr_result[i].cls_label % 2 == 1 &&
|
||||
ocr_result[i].cls_score > this->pri_->classifier_->cls_thresh) {
|
||||
cv::rotate(img_list[i], img_list[i], 1);
|
||||
@ -124,13 +124,13 @@ void PPOCR::det(const cv::Mat &img,
|
||||
|
||||
this->pri_->detector_->Run(img, boxes, det_times);
|
||||
|
||||
for (int i = 0; i < boxes.size(); ++i) {
|
||||
for (size_t i = 0; i < boxes.size(); ++i) {
|
||||
OCRPredictResult res;
|
||||
res.box = std::move(boxes[i]);
|
||||
ocr_results.emplace_back(std::move(res));
|
||||
}
|
||||
// sort boxes from top to bottom, from left to right
|
||||
Utility::sorted_boxes(ocr_results);
|
||||
Utility::sort_boxes(ocr_results);
|
||||
this->time_info_det[0] += det_times[0];
|
||||
this->time_info_det[1] += det_times[1];
|
||||
this->time_info_det[2] += det_times[2];
|
||||
@ -143,7 +143,7 @@ void PPOCR::rec(const std::vector<cv::Mat> &img_list,
|
||||
std::vector<double> rec_times;
|
||||
this->pri_->recognizer_->Run(img_list, rec_texts, rec_text_scores, rec_times);
|
||||
// output rec results
|
||||
for (int i = 0; i < rec_texts.size(); ++i) {
|
||||
for (size_t i = 0; i < rec_texts.size(); ++i) {
|
||||
ocr_results[i].text = std::move(rec_texts[i]);
|
||||
ocr_results[i].score = rec_text_scores[i];
|
||||
}
|
||||
@ -159,7 +159,7 @@ void PPOCR::cls(const std::vector<cv::Mat> &img_list,
|
||||
std::vector<double> cls_times;
|
||||
this->pri_->classifier_->Run(img_list, cls_labels, cls_scores, cls_times);
|
||||
// output cls results
|
||||
for (int i = 0; i < cls_labels.size(); ++i) {
|
||||
for (size_t i = 0; i < cls_labels.size(); ++i) {
|
||||
ocr_results[i].cls_label = cls_labels[i];
|
||||
ocr_results[i].cls_score = cls_scores[i];
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ PaddleStructure::structure(const cv::Mat &srcimg, bool layout, bool table,
|
||||
structure_results.emplace_back(std::move(res));
|
||||
}
|
||||
cv::Mat roi_img;
|
||||
for (int i = 0; i < structure_results.size(); ++i) {
|
||||
for (size_t i = 0; i < structure_results.size(); ++i) {
|
||||
// crop image
|
||||
roi_img = std::move(Utility::crop_image(img, structure_results[i].box));
|
||||
if (structure_results[i].type == "table" && table) {
|
||||
@ -108,13 +108,13 @@ void PaddleStructure::table(const cv::Mat &img,
|
||||
std::vector<OCRPredictResult> ocr_result;
|
||||
int expand_pixel = 3;
|
||||
|
||||
for (int i = 0; i < img_list.size(); ++i) {
|
||||
for (size_t i = 0; i < img_list.size(); ++i) {
|
||||
// det
|
||||
this->det(img_list[i], ocr_result);
|
||||
// crop image
|
||||
std::vector<cv::Mat> rec_img_list;
|
||||
std::vector<int> ocr_box;
|
||||
for (int j = 0; j < ocr_result.size(); j++) {
|
||||
for (size_t j = 0; j < ocr_result.size(); ++j) {
|
||||
ocr_box = std::move(Utility::xyxyxyxy2xyxy(ocr_result[j].box));
|
||||
ocr_box[0] = std::max(0, ocr_box[0] - expand_pixel);
|
||||
ocr_box[1] = std::max(0, ocr_box[1] - expand_pixel),
|
||||
@ -144,7 +144,7 @@ std::string PaddleStructure::rebuild_table(
|
||||
|
||||
std::vector<int> ocr_box;
|
||||
std::vector<int> structure_box;
|
||||
for (int i = 0; i < ocr_result.size(); ++i) {
|
||||
for (size_t i = 0; i < ocr_result.size(); ++i) {
|
||||
ocr_box = std::move(Utility::xyxyxyxy2xyxy(ocr_result[i].box));
|
||||
ocr_box[0] -= 1;
|
||||
ocr_box[1] -= 1;
|
||||
@ -152,8 +152,8 @@ std::string PaddleStructure::rebuild_table(
|
||||
ocr_box[3] += 1;
|
||||
std::vector<std::vector<float>> dis_list(structure_boxes.size(),
|
||||
std::vector<float>(3, 100000.0));
|
||||
for (int j = 0; j < structure_boxes.size(); j++) {
|
||||
if (structure_boxes[i].size() == 8) {
|
||||
for (size_t j = 0; j < structure_boxes.size(); ++j) {
|
||||
if (structure_boxes[j].size() == 8) {
|
||||
structure_box = std::move(Utility::xyxyxyxy2xyxy(structure_boxes[j]));
|
||||
} else {
|
||||
structure_box = structure_boxes[j];
|
||||
@ -171,7 +171,7 @@ std::string PaddleStructure::rebuild_table(
|
||||
// get pred html
|
||||
std::string html_str = "";
|
||||
int td_tag_idx = 0;
|
||||
for (int i = 0; i < structure_html_tags.size(); ++i) {
|
||||
for (size_t i = 0; i < structure_html_tags.size(); ++i) {
|
||||
if (structure_html_tags[i].find("</td>") != std::string::npos) {
|
||||
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
|
||||
html_str += "<td>";
|
||||
@ -183,7 +183,7 @@ std::string PaddleStructure::rebuild_table(
|
||||
b_with = true;
|
||||
html_str += "<b>";
|
||||
}
|
||||
for (int j = 0; j < matched[td_tag_idx].size(); j++) {
|
||||
for (size_t j = 0; j < matched[td_tag_idx].size(); ++j) {
|
||||
std::string content = matched[td_tag_idx][j];
|
||||
if (matched[td_tag_idx].size() > 1) {
|
||||
// remove blank, <b> and </b>
|
||||
|
||||
@ -57,8 +57,8 @@ DBPostProcessor::UnClip(const std::vector<std::vector<float>> &box,
|
||||
|
||||
std::vector<cv::Point2f> points;
|
||||
|
||||
for (int j = 0; j < soln.size(); j++) {
|
||||
for (int i = 0; i < soln[soln.size() - 1].size(); ++i) {
|
||||
for (size_t j = 0; j < soln.size(); ++j) {
|
||||
for (size_t i = 0; i < soln[soln.size() - 1].size(); ++i) {
|
||||
points.emplace_back(soln[j][i].X, soln[j][i].Y);
|
||||
}
|
||||
}
|
||||
@ -173,7 +173,7 @@ float DBPostProcessor::PolygonScoreAcc(const std::vector<cv::Point> &contour,
|
||||
int height = pred.rows;
|
||||
std::vector<float> box_x;
|
||||
std::vector<float> box_y;
|
||||
for (int i = 0; i < contour.size(); ++i) {
|
||||
for (size_t i = 0; i < contour.size(); ++i) {
|
||||
box_x.emplace_back(contour[i].x);
|
||||
box_y.emplace_back(contour[i].y);
|
||||
}
|
||||
@ -196,7 +196,7 @@ float DBPostProcessor::PolygonScoreAcc(const std::vector<cv::Point> &contour,
|
||||
|
||||
cv::Point *rook_point = new cv::Point[contour.size()];
|
||||
|
||||
for (int i = 0; i < contour.size(); ++i) {
|
||||
for (size_t i = 0; i < contour.size(); ++i) {
|
||||
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
|
||||
}
|
||||
const cv::Point *ppt[1] = {rook_point};
|
||||
@ -273,7 +273,7 @@ std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
|
||||
|
||||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
|
||||
for (int _i = 0; _i < num_contours; _i++) {
|
||||
for (int _i = 0; _i < num_contours; ++_i) {
|
||||
if (contours[_i].size() <= 2) {
|
||||
continue;
|
||||
}
|
||||
@ -315,7 +315,7 @@ std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
|
||||
int dest_height = pred.rows;
|
||||
std::vector<std::vector<int>> intcliparray;
|
||||
|
||||
for (int num_pt = 0; num_pt < 4; num_pt++) {
|
||||
for (int num_pt = 0; num_pt < 4; ++num_pt) {
|
||||
std::vector<int> a{int(clampf(roundf(cliparray[num_pt][0] / float(width) *
|
||||
float(dest_width)),
|
||||
0, float(dest_width))),
|
||||
@ -337,9 +337,9 @@ void DBPostProcessor::FilterTagDetRes(
|
||||
int oriimg_w = srcimg.cols;
|
||||
|
||||
std::vector<std::vector<std::vector<int>>> root_points;
|
||||
for (int n = 0; n < boxes.size(); n++) {
|
||||
for (size_t n = 0; n < boxes.size(); ++n) {
|
||||
boxes[n] = OrderPointsClockwise(boxes[n]);
|
||||
for (int m = 0; m < boxes[0].size(); m++) {
|
||||
for (size_t m = 0; m < boxes[0].size(); ++m) {
|
||||
boxes[n][m][0] /= ratio_w;
|
||||
boxes[n][m][1] /= ratio_h;
|
||||
|
||||
@ -348,7 +348,7 @@ void DBPostProcessor::FilterTagDetRes(
|
||||
}
|
||||
}
|
||||
|
||||
for (int n = 0; n < boxes.size(); n++) {
|
||||
for (size_t n = 0; n < boxes.size(); ++n) {
|
||||
int rect_width, rect_height;
|
||||
rect_width = int(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) +
|
||||
pow(boxes[n][0][1] - boxes[n][1][1], 2)));
|
||||
@ -389,7 +389,7 @@ void TablePostProcessor::Run(
|
||||
std::vector<std::vector<std::vector<int>>> &rec_boxes_batch,
|
||||
const std::vector<int> &width_list,
|
||||
const std::vector<int> &height_list) noexcept {
|
||||
for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) {
|
||||
for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; ++batch_idx) {
|
||||
// image tags and boxes
|
||||
std::vector<std::string> rec_html_tags;
|
||||
std::vector<std::vector<int>> rec_boxes;
|
||||
@ -400,7 +400,7 @@ void TablePostProcessor::Run(
|
||||
int char_idx = 0;
|
||||
|
||||
// step
|
||||
for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) {
|
||||
for (int step_idx = 0; step_idx < structure_probs_shape[1]; ++step_idx) {
|
||||
std::string html_tag;
|
||||
std::vector<int> rec_box;
|
||||
// html tag
|
||||
@ -426,7 +426,7 @@ void TablePostProcessor::Run(
|
||||
|
||||
// box
|
||||
if (html_tag == "<td>" || html_tag == "<td" || html_tag == "<td></td>") {
|
||||
for (int point_idx = 0; point_idx < loc_preds_shape[2]; point_idx++) {
|
||||
for (int point_idx = 0; point_idx < loc_preds_shape[2]; ++point_idx) {
|
||||
step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
|
||||
loc_preds_shape[2] +
|
||||
point_idx;
|
||||
@ -474,16 +474,18 @@ void PicodetPostProcessor::Run(std::vector<StructurePredictResult> &results,
|
||||
|
||||
std::vector<std::vector<StructurePredictResult>> bbox_results;
|
||||
bbox_results.resize(this->num_class_);
|
||||
for (int i = 0; i < this->fpn_stride_.size(); ++i) {
|
||||
int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]);
|
||||
int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]);
|
||||
for (int idx = 0; idx < feature_h * feature_w; idx++) {
|
||||
for (size_t i = 0; i < this->fpn_stride_.size(); ++i) {
|
||||
const int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]);
|
||||
const int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]);
|
||||
const size_t hxw = feature_h * feature_w;
|
||||
for (size_t idx = 0; idx < hxw; ++idx) {
|
||||
// score and label
|
||||
float score = 0;
|
||||
int cur_label = 0;
|
||||
for (int label = 0; label < this->num_class_; label++) {
|
||||
if (outs[i][idx * this->num_class_ + label] > score) {
|
||||
score = outs[i][idx * this->num_class_ + label];
|
||||
for (size_t label = 0; label < this->label_list_.size(); ++label) {
|
||||
float osc = outs[i][idx * this->label_list_.size() + label];
|
||||
if (osc > score) {
|
||||
score = osc;
|
||||
cur_label = label;
|
||||
}
|
||||
}
|
||||
@ -501,11 +503,11 @@ void PicodetPostProcessor::Run(std::vector<StructurePredictResult> &results,
|
||||
}
|
||||
}
|
||||
#if 0
|
||||
for (int i = 0; i < bbox_results.size(); ++i) {
|
||||
for (size_t i = 0; i < bbox_results.size(); ++i) {
|
||||
bool flag = bbox_results[i].size() <= 0;
|
||||
}
|
||||
#endif
|
||||
for (int i = 0; i < bbox_results.size(); ++i) {
|
||||
for (size_t i = 0; i < bbox_results.size(); ++i) {
|
||||
// bool flag = bbox_results[i].size() <= 0;
|
||||
if (bbox_results[i].size() <= 0) {
|
||||
continue;
|
||||
@ -534,7 +536,7 @@ StructurePredictResult PicodetPostProcessor::disPred2Bbox(
|
||||
std::vector<float> bbox_pred_i(itemp, itemp + reg_max);
|
||||
std::vector<float> dis_after_sm(
|
||||
std::move(Utility::activation_function_softmax(bbox_pred_i)));
|
||||
for (int j = 0; j < reg_max; j++) {
|
||||
for (int j = 0; j < reg_max; ++j) {
|
||||
dis += j * dis_after_sm[j];
|
||||
}
|
||||
dis *= stride;
|
||||
@ -562,11 +564,11 @@ void PicodetPostProcessor::nms(std::vector<StructurePredictResult> &input_boxes,
|
||||
});
|
||||
std::vector<int> picked(input_boxes.size(), 1);
|
||||
|
||||
for (int i = 0; i < input_boxes.size(); ++i) {
|
||||
for (size_t i = 0; i < input_boxes.size(); ++i) {
|
||||
if (picked[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
for (int j = i + 1; j < input_boxes.size(); ++j) {
|
||||
for (size_t j = i + 1; j < input_boxes.size(); ++j) {
|
||||
if (picked[j] == 0) {
|
||||
continue;
|
||||
}
|
||||
@ -577,7 +579,7 @@ void PicodetPostProcessor::nms(std::vector<StructurePredictResult> &input_boxes,
|
||||
}
|
||||
}
|
||||
std::vector<StructurePredictResult> input_boxes_nms;
|
||||
for (int i = 0; i < input_boxes.size(); ++i) {
|
||||
for (size_t i = 0; i < input_boxes.size(); ++i) {
|
||||
if (picked[i] == 1) {
|
||||
input_boxes_nms.emplace_back(input_boxes[i]);
|
||||
}
|
||||
|
||||
@ -26,7 +26,7 @@ void Permute::Run(const cv::Mat &im, float *data) noexcept {
|
||||
}
|
||||
|
||||
void PermuteBatch::Run(const std::vector<cv::Mat> &imgs, float *data) noexcept {
|
||||
for (int j = 0; j < imgs.size(); j++) {
|
||||
for (size_t j = 0; j < imgs.size(); ++j) {
|
||||
int rh = imgs[j].rows;
|
||||
int rw = imgs[j].cols;
|
||||
int rc = imgs[j].channels();
|
||||
@ -47,7 +47,7 @@ void Normalize::Run(cv::Mat &im, const std::vector<float> &mean,
|
||||
im.convertTo(im, CV_32FC3, e);
|
||||
std::vector<cv::Mat> bgr_channels(3);
|
||||
cv::split(im, bgr_channels);
|
||||
for (auto i = 0; i < bgr_channels.size(); ++i) {
|
||||
for (size_t i = 0; i < bgr_channels.size(); ++i) {
|
||||
bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i],
|
||||
(0.0 - mean[i]) * scale[i]);
|
||||
}
|
||||
|
||||
@ -59,7 +59,7 @@ void StructureLayoutRecognizer::Run(const cv::Mat &img,
|
||||
std::vector<std::vector<float>> out_tensor_list;
|
||||
std::vector<std::vector<int>> output_shape_list;
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
for (int j = 0; j < output_names.size(); j++) {
|
||||
for (size_t j = 0; j < output_names.size(); ++j) {
|
||||
auto output_tensor = this->predictor_->GetOutputHandle(output_names[j]);
|
||||
std::vector<int> output_shape = output_tensor->shape();
|
||||
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
|
||||
@ -79,7 +79,7 @@ void StructureLayoutRecognizer::Run(const cv::Mat &img,
|
||||
|
||||
std::vector<int> bbox_num;
|
||||
int reg_max = 0;
|
||||
for (int i = 0; i < out_tensor_list.size(); ++i) {
|
||||
for (size_t i = 0; i < out_tensor_list.size(); ++i) {
|
||||
if (i == this->post_processor_.fpn_stride_size()) {
|
||||
reg_max = output_shape_list[i][2] / 4;
|
||||
break;
|
||||
|
||||
@ -33,17 +33,17 @@ void StructureTableRecognizer::Run(
|
||||
std::chrono::duration<float> postprocess_diff =
|
||||
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
|
||||
|
||||
int img_num = img_list.size();
|
||||
for (int beg_img_no = 0; beg_img_no < img_num;
|
||||
size_t img_num = img_list.size();
|
||||
for (size_t beg_img_no = 0; beg_img_no < img_num;
|
||||
beg_img_no += this->table_batch_num_) {
|
||||
// preprocess
|
||||
auto preprocess_start = std::chrono::steady_clock::now();
|
||||
int end_img_no = std::min(img_num, beg_img_no + this->table_batch_num_);
|
||||
size_t end_img_no = std::min(img_num, beg_img_no + this->table_batch_num_);
|
||||
int batch_num = end_img_no - beg_img_no;
|
||||
std::vector<cv::Mat> norm_img_batch;
|
||||
std::vector<int> width_list;
|
||||
std::vector<int> height_list;
|
||||
for (int ino = beg_img_no; ino < end_img_no; ino++) {
|
||||
for (size_t ino = beg_img_no; ino < end_img_no; ++ino) {
|
||||
cv::Mat srcimg;
|
||||
img_list[ino].copyTo(srcimg);
|
||||
cv::Mat resize_img;
|
||||
@ -98,7 +98,7 @@ void StructureTableRecognizer::Run(
|
||||
predict_shape0, predict_shape1,
|
||||
structure_html_tag_batch, structure_boxes_batch,
|
||||
width_list, height_list);
|
||||
for (int m = 0; m < predict_shape0[0]; m++) {
|
||||
for (int m = 0; m < predict_shape0[0]; ++m) {
|
||||
|
||||
structure_html_tag_batch[m].emplace(structure_html_tag_batch[m].begin(),
|
||||
"<table>");
|
||||
|
||||
@ -52,9 +52,9 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg,
|
||||
const std::string &save_path) noexcept {
|
||||
cv::Mat img_vis;
|
||||
srcimg.copyTo(img_vis);
|
||||
for (int n = 0; n < ocr_result.size(); ++n) {
|
||||
for (size_t n = 0; n < ocr_result.size(); ++n) {
|
||||
cv::Point rook_points[4];
|
||||
for (int m = 0; m < ocr_result[n].box.size(); ++m) {
|
||||
for (size_t m = 0; m < ocr_result[n].box.size(); ++m) {
|
||||
rook_points[m] =
|
||||
cv::Point(int(ocr_result[n].box[m][0]), int(ocr_result[n].box[m][1]));
|
||||
}
|
||||
@ -75,10 +75,10 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg,
|
||||
cv::Mat img_vis;
|
||||
srcimg.copyTo(img_vis);
|
||||
img_vis = crop_image(img_vis, structure_result.box);
|
||||
for (int n = 0; n < structure_result.cell_box.size(); ++n) {
|
||||
for (size_t n = 0; n < structure_result.cell_box.size(); ++n) {
|
||||
if (structure_result.cell_box[n].size() == 8) {
|
||||
cv::Point rook_points[4];
|
||||
for (int m = 0; m < structure_result.cell_box[n].size(); m += 2) {
|
||||
for (size_t m = 0; m < structure_result.cell_box[n].size(); m += 2) {
|
||||
rook_points[m / 2] =
|
||||
cv::Point(int(structure_result.cell_box[n][m]),
|
||||
int(structure_result.cell_box[n][m + 1]));
|
||||
@ -151,7 +151,7 @@ Utility::GetRotateCropImage(const cv::Mat &srcimage,
|
||||
cv::Mat img_crop;
|
||||
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
|
||||
|
||||
for (int i = 0; i < points.size(); ++i) {
|
||||
for (size_t i = 0; i < points.size(); ++i) {
|
||||
points[i][0] -= left;
|
||||
points[i][1] -= top;
|
||||
}
|
||||
@ -190,14 +190,15 @@ Utility::GetRotateCropImage(const cv::Mat &srcimage,
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> Utility::argsort(const std::vector<float> &array) noexcept {
|
||||
std::vector<int> array_index(array.size(), 0);
|
||||
for (int i = 0; i < array.size(); ++i)
|
||||
std::vector<size_t> Utility::argsort(const std::vector<float> &array) noexcept {
|
||||
std::vector<size_t> array_index(array.size(), 0);
|
||||
for (size_t i = 0; i < array.size(); ++i)
|
||||
array_index[i] = i;
|
||||
|
||||
std::sort(
|
||||
array_index.begin(), array_index.end(),
|
||||
[&array](int pos1, int pos2) { return (array[pos1] < array[pos2]); });
|
||||
std::sort(array_index.begin(), array_index.end(),
|
||||
[&array](size_t pos1, size_t pos2) {
|
||||
return (array[pos1] < array[pos2]);
|
||||
});
|
||||
|
||||
return array_index;
|
||||
}
|
||||
@ -237,35 +238,35 @@ std::string Utility::basename(const std::string &filename) noexcept {
|
||||
return filename.substr(index + 1, len - index);
|
||||
}
|
||||
|
||||
bool Utility::PathExists(const std::string &path) noexcept {
|
||||
bool Utility::PathExists(const char *path) noexcept {
|
||||
#ifdef _WIN32
|
||||
struct _stat buffer;
|
||||
return (_stat(path.c_str(), &buffer) == 0);
|
||||
return (_stat(path, &buffer) == 0);
|
||||
#else
|
||||
struct stat buffer;
|
||||
return (stat(path.c_str(), &buffer) == 0);
|
||||
return (stat(path, &buffer) == 0);
|
||||
#endif // !_WIN32
|
||||
}
|
||||
|
||||
void Utility::CreateDir(const std::string &path) noexcept {
|
||||
void Utility::CreateDir(const char *path) noexcept {
|
||||
#ifdef _MSC_VER
|
||||
_mkdir(path.c_str());
|
||||
_mkdir(path);
|
||||
#elif defined __MINGW32__
|
||||
mkdir(path.c_str());
|
||||
mkdir(path);
|
||||
#else
|
||||
mkdir(path.c_str(), 0777);
|
||||
mkdir(path, 0777);
|
||||
#endif // !_WIN32
|
||||
}
|
||||
|
||||
void Utility::print_result(
|
||||
const std::vector<OCRPredictResult> &ocr_result) noexcept {
|
||||
for (int i = 0; i < ocr_result.size(); ++i) {
|
||||
for (size_t i = 0; i < ocr_result.size(); ++i) {
|
||||
std::cout << i << "\t";
|
||||
// det
|
||||
const std::vector<std::vector<int>> &boxes = ocr_result[i].box;
|
||||
if (boxes.size() > 0) {
|
||||
std::cout << "det boxes: [";
|
||||
for (int n = 0; n < boxes.size(); n++) {
|
||||
for (size_t n = 0; n < boxes.size(); ++n) {
|
||||
std::cout << '[' << boxes[n][0] << ',' << boxes[n][1] << "]";
|
||||
if (n != boxes.size() - 1) {
|
||||
std::cout << ',';
|
||||
@ -288,7 +289,7 @@ void Utility::print_result(
|
||||
}
|
||||
}
|
||||
|
||||
cv::Mat Utility::crop_image(cv::Mat &img,
|
||||
cv::Mat Utility::crop_image(const cv::Mat &img,
|
||||
const std::vector<int> &box) noexcept {
|
||||
cv::Mat crop_im = cv::Mat::zeros(box[3] - box[1], box[2] - box[0], 16);
|
||||
int crop_x1 = std::max(0, box[0]);
|
||||
@ -305,18 +306,18 @@ cv::Mat Utility::crop_image(cv::Mat &img,
|
||||
return crop_im;
|
||||
}
|
||||
|
||||
cv::Mat Utility::crop_image(cv::Mat &img,
|
||||
cv::Mat Utility::crop_image(const cv::Mat &img,
|
||||
const std::vector<float> &box) noexcept {
|
||||
std::vector<int> box_int = {(int)box[0], (int)box[1], (int)box[2],
|
||||
(int)box[3]};
|
||||
return crop_image(img, box_int);
|
||||
}
|
||||
|
||||
void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) noexcept {
|
||||
void Utility::sort_boxes(std::vector<OCRPredictResult> &ocr_result) noexcept {
|
||||
std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
|
||||
if (ocr_result.size() > 0) {
|
||||
for (int i = 0; i < ocr_result.size() - 1; ++i) {
|
||||
for (int j = i; j >= 0; j--) {
|
||||
if (ocr_result.size() > 1) {
|
||||
for (size_t i = 0; i < ocr_result.size() - 1; ++i) {
|
||||
for (size_t j = i; j != size_t(-1); --j) {
|
||||
if (abs(ocr_result[j + 1].box[0][1] - ocr_result[j].box[0][1]) < 10 &&
|
||||
(ocr_result[j + 1].box[0][0] < ocr_result[j].box[0][0])) {
|
||||
std::swap(ocr_result[i], ocr_result[i + 1]);
|
||||
@ -367,25 +368,26 @@ float Utility::fast_exp(float x) noexcept {
|
||||
}
|
||||
|
||||
std::vector<float>
|
||||
Utility::activation_function_softmax(std::vector<float> &src) noexcept {
|
||||
int length = src.size();
|
||||
Utility::activation_function_softmax(const std::vector<float> &src) noexcept {
|
||||
size_t length = src.size();
|
||||
std::vector<float> dst;
|
||||
dst.resize(length);
|
||||
const float alpha = float(*std::max_element(&src[0], &src[0 + length]));
|
||||
const float alpha = float(*std::max_element(&src[0], &src[length]));
|
||||
float denominator{0};
|
||||
|
||||
for (int i = 0; i < length; ++i) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
dst[i] = fast_exp(src[i] - alpha);
|
||||
denominator += dst[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < length; ++i) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
dst[i] /= denominator;
|
||||
}
|
||||
return dst;
|
||||
}
|
||||
|
||||
float Utility::iou(std::vector<int> &box1, std::vector<int> &box2) noexcept {
|
||||
float Utility::iou(const std::vector<int> &box1,
|
||||
const std::vector<int> &box2) noexcept {
|
||||
int area1 = std::max(0, box1[2] - box1[0]) * std::max(0, box1[3] - box1[1]);
|
||||
int area2 = std::max(0, box2[2] - box2[0]) * std::max(0, box2[3] - box2[1]);
|
||||
|
||||
@ -407,8 +409,8 @@ float Utility::iou(std::vector<int> &box1, std::vector<int> &box2) noexcept {
|
||||
}
|
||||
}
|
||||
|
||||
float Utility::iou(std::vector<float> &box1,
|
||||
std::vector<float> &box2) noexcept {
|
||||
float Utility::iou(const std::vector<float> &box1,
|
||||
const std::vector<float> &box2) noexcept {
|
||||
float area1 = std::max((float)0.0, box1[2] - box1[0]) *
|
||||
std::max((float)0.0, box1[3] - box1[1]);
|
||||
float area2 = std::max((float)0.0, box2[2] - box2[0]) *
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user