PaddleOCR/deploy/cpp_infer/src/paddlestructure.cpp

290 lines
10 KiB
C++
Raw Normal View History

2022-08-10 06:05:02 +00:00
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/args.h>
#include <include/paddlestructure.h>
#include "auto_log/autolog.h"
namespace PaddleOCR {
// Constructs the structure pipeline. Each sub-model is instantiated only when
// its command-line switch (FLAGS_layout / FLAGS_table) is enabled; otherwise
// the corresponding model pointer is left as it was.
PaddleStructure::PaddleStructure() {
  // Layout recognizer.
  if (FLAGS_layout) {
    layout_model_.reset(new StructureLayoutRecognizer(
        FLAGS_layout_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_layout_dict_path,
        FLAGS_use_tensorrt, FLAGS_precision, FLAGS_layout_score_threshold,
        FLAGS_layout_nms_threshold));
  }

  // Table-structure recognizer.
  if (FLAGS_table) {
    table_model_.reset(new StructureTableRecognizer(
        FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
        FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
        FLAGS_table_max_len, FLAGS_merge_no_span_structure));
  }
}
2022-08-10 06:05:02 +00:00
2022-09-19 09:37:41 +00:00
// Runs the structure pipeline on one image and returns one result per region.
//
// srcimg: input image; it is copied into a separate cv::Mat before use.
// layout: when true, regions come from the layout model; when false, a single
//         region of type "table" spanning the whole image is used instead.
// table:  when true, regions whose type is "table" go through table().
// ocr:    when true, every region not handled by table() is passed to ocr().
std::vector<StructurePredictResult>
PaddleStructure::structure(cv::Mat srcimg, bool layout, bool table, bool ocr) {
  cv::Mat img;
  srcimg.copyTo(img);

  std::vector<StructurePredictResult> structure_results;
  if (!layout) {
    // No layout analysis requested: treat the full image as one table region
    // with box (0, 0, cols, rows).
    StructurePredictResult whole_image;
    whole_image.type = "table";
    whole_image.box = std::vector<float>(4, 0.0);
    whole_image.box[2] = img.cols;
    whole_image.box[3] = img.rows;
    structure_results.push_back(whole_image);
  } else {
    this->layout(img, structure_results);
  }

  for (auto &region : structure_results) {
    // crop image
    cv::Mat roi_img = Utility::crop_image(img, region.box);
    if (region.type == "table" && table) {
      this->table(roi_img, region);
    } else if (ocr) {
      region.text_res = this->ocr(roi_img, true, true, false);
    }
  }

  return structure_results;
}
2022-08-10 06:05:02 +00:00
2022-09-19 09:37:41 +00:00
// Runs the layout model on img, letting it fill structure_result, and adds the
// three timing values it reports to time_info_layout.
void PaddleStructure::layout(
    cv::Mat img, std::vector<StructurePredictResult> &structure_result) {
  std::vector<double> stage_times;
  layout_model_->Run(img, structure_result, stage_times);

  // Accumulate the three reported timings slot by slot.
  for (int slot = 0; slot < 3; ++slot) {
    time_info_layout[slot] += stage_times[slot];
  }
}
2022-08-10 06:05:02 +00:00
void PaddleStructure::table(cv::Mat img,
2022-09-19 09:37:41 +00:00
StructurePredictResult &structure_result) {
2022-08-10 06:05:02 +00:00
// predict structure
std::vector<std::vector<std::string>> structure_html_tags;
std::vector<float> structure_scores(1, 0);
2022-09-09 12:24:08 +00:00
std::vector<std::vector<std::vector<int>>> structure_boxes;
2022-09-19 09:37:41 +00:00
std::vector<double> structure_times;
2022-08-10 06:05:02 +00:00
std::vector<cv::Mat> img_list;
img_list.push_back(img);
2022-09-19 09:37:41 +00:00
this->table_model_->Run(img_list, structure_html_tags, structure_scores,
structure_boxes, structure_times);
this->time_info_table[0] += structure_times[0];
this->time_info_table[1] += structure_times[1];
this->time_info_table[2] += structure_times[2];
2022-08-10 06:05:02 +00:00
std::vector<OCRPredictResult> ocr_result;
std::string html;
int expand_pixel = 3;
for (int i = 0; i < img_list.size(); i++) {
// det
2022-09-19 09:37:41 +00:00
this->det(img_list[i], ocr_result);
2022-08-10 06:05:02 +00:00
// crop image
std::vector<cv::Mat> rec_img_list;
2022-09-09 12:24:08 +00:00
std::vector<int> ocr_box;
2022-08-10 06:05:02 +00:00
for (int j = 0; j < ocr_result.size(); j++) {
2022-09-09 12:24:08 +00:00
ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[j].box);
2022-09-20 03:40:05 +00:00
ocr_box[0] = std::max(0, ocr_box[0] - expand_pixel);
ocr_box[1] = std::max(0, ocr_box[1] - expand_pixel),
ocr_box[2] = std::min(img_list[i].cols, ocr_box[2] + expand_pixel);
ocr_box[3] = std::min(img_list[i].rows, ocr_box[3] + expand_pixel);
2022-09-09 12:24:08 +00:00
cv::Mat crop_img = Utility::crop_image(img_list[i], ocr_box);
2022-08-10 06:05:02 +00:00
rec_img_list.push_back(crop_img);
}
// rec
2022-09-19 09:37:41 +00:00
this->rec(rec_img_list, ocr_result);
2022-08-10 06:05:02 +00:00
// rebuild table
html = this->rebuild_table(structure_html_tags[i], structure_boxes[i],
ocr_result);
structure_result.html = html;
2022-09-09 12:24:08 +00:00
structure_result.cell_box = structure_boxes[i];
2022-08-10 06:05:02 +00:00
structure_result.html_score = structure_scores[i];
}
}
2022-08-10 06:05:02 +00:00
2022-09-09 12:24:08 +00:00
std::string
PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
std::vector<std::vector<int>> structure_boxes,
std::vector<OCRPredictResult> &ocr_result) {
2022-08-10 06:05:02 +00:00
// match text in same cell
2022-09-20 03:40:05 +00:00
std::vector<std::vector<std::string>> matched(structure_boxes.size(),
std::vector<std::string>());
2022-08-10 06:05:02 +00:00
2022-09-09 12:24:08 +00:00
std::vector<int> ocr_box;
std::vector<int> structure_box;
2022-08-10 06:05:02 +00:00
for (int i = 0; i < ocr_result.size(); i++) {
2022-09-09 12:24:08 +00:00
ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[i].box);
ocr_box[0] -= 1;
ocr_box[1] -= 1;
ocr_box[2] += 1;
ocr_box[3] += 1;
2022-08-10 06:05:02 +00:00
std::vector<std::vector<float>> dis_list(structure_boxes.size(),
std::vector<float>(3, 100000.0));
for (int j = 0; j < structure_boxes.size(); j++) {
2022-09-09 12:24:08 +00:00
if (structure_boxes[i].size() == 8) {
structure_box = Utility::xyxyxyxy2xyxy(structure_boxes[j]);
} else {
structure_box = structure_boxes[j];
}
dis_list[j][0] = this->dis(ocr_box, structure_box);
2022-09-19 09:37:41 +00:00
dis_list[j][1] = 1 - Utility::iou(ocr_box, structure_box);
2022-08-10 06:05:02 +00:00
dis_list[j][2] = j;
}
// find min dis idx
std::sort(dis_list.begin(), dis_list.end(),
PaddleStructure::comparison_dis);
matched[dis_list[0][2]].push_back(ocr_result[i].text);
}
2022-09-09 12:24:08 +00:00
2022-08-10 06:05:02 +00:00
// get pred html
std::string html_str = "";
int td_tag_idx = 0;
for (int i = 0; i < structure_html_tags.size(); i++) {
if (structure_html_tags[i].find("</td>") != std::string::npos) {
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
html_str += "<td>";
}
if (matched[td_tag_idx].size() > 0) {
bool b_with = false;
if (matched[td_tag_idx][0].find("<b>") != std::string::npos &&
matched[td_tag_idx].size() > 1) {
b_with = true;
html_str += "<b>";
}
for (int j = 0; j < matched[td_tag_idx].size(); j++) {
std::string content = matched[td_tag_idx][j];
if (matched[td_tag_idx].size() > 1) {
// remove blank, <b> and </b>
if (content.length() > 0 && content.at(0) == ' ') {
content = content.substr(0);
}
if (content.length() > 2 && content.substr(0, 3) == "<b>") {
content = content.substr(3);
}
if (content.length() > 4 &&
content.substr(content.length() - 4) == "</b>") {
content = content.substr(0, content.length() - 4);
}
if (content.empty()) {
continue;
}
// add blank
if (j != matched[td_tag_idx].size() - 1 &&
content.at(content.length() - 1) != ' ') {
content += ' ';
}
}
html_str += content;
}
if (b_with) {
html_str += "</b>";
}
}
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
html_str += "</td>";
} else {
html_str += structure_html_tags[i];
}
td_tag_idx += 1;
} else {
html_str += structure_html_tags[i];
}
}
return html_str;
}
2022-09-09 12:24:08 +00:00
// Distance score between two boxes, each read through indices 0..3
// (presumably x1, y1, x2, y2 -- confirm against callers).
//
// The score is the sum of absolute differences over all four coordinates,
// plus the smaller of the two partial sums: the one over indices 0-1 or the
// one over indices 2-3.
float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) {
  // Integer partial sums, one per coordinate pair.
  const int first_pair_gap = abs(box2[0] - box1[0]) + abs(box2[1] - box1[1]);
  const int second_pair_gap = abs(box2[2] - box1[2]) + abs(box2[3] - box1[3]);

  const float full_gap = first_pair_gap + second_pair_gap;
  const float first_gap = first_pair_gap;
  const float second_gap = second_pair_gap;

  return full_gap + std::min(first_gap, second_gap);
}
2022-09-19 09:37:41 +00:00
void PaddleStructure::reset_timer() {
this->time_info_det = {0, 0, 0};
this->time_info_rec = {0, 0, 0};
this->time_info_cls = {0, 0, 0};
this->time_info_table = {0, 0, 0};
this->time_info_layout = {0, 0, 0};
}
// Emits one AutoLogger report per pipeline stage that accumulated any time.
//
// A stage is reported only when the sum of its three timing slots is positive.
// The sixth AutoLogger argument is the slot where rec and cls pass their own
// *_batch_num flag; table now passes FLAGS_table_batch_num and layout passes
// 1 (its Run call takes a single image), instead of both reusing the cls
// batch flag.
//
// img_num: image count forwarded to every AutoLogger.
void PaddleStructure::benchmark_log(int img_num) {
  if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] >
      0) {
    AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
                           FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
                           FLAGS_precision, this->time_info_det, img_num);
    autolog_det.report();
  }
  if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] >
      0) {
    AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
                           FLAGS_enable_mkldnn, FLAGS_cpu_threads,
                           FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
                           this->time_info_rec, img_num);
    autolog_rec.report();
  }
  if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] >
      0) {
    AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
                           FLAGS_enable_mkldnn, FLAGS_cpu_threads,
                           FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
                           this->time_info_cls, img_num);
    autolog_cls.report();
  }
  if (this->time_info_table[0] + this->time_info_table[1] +
          this->time_info_table[2] >
      0) {
    AutoLogger autolog_table("table", FLAGS_use_gpu, FLAGS_use_tensorrt,
                             FLAGS_enable_mkldnn, FLAGS_cpu_threads,
                             FLAGS_table_batch_num, "dynamic", FLAGS_precision,
                             this->time_info_table, img_num);
    autolog_table.report();
  }
  if (this->time_info_layout[0] + this->time_info_layout[1] +
          this->time_info_layout[2] >
      0) {
    AutoLogger autolog_layout("layout", FLAGS_use_gpu, FLAGS_use_tensorrt,
                              FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1,
                              "dynamic", FLAGS_precision,
                              this->time_info_layout, img_num);
    autolog_layout.report();
  }
}
} // namespace PaddleOCR