// Mirror of https://github.com/PaddlePaddle/PaddleOCR.git
// (synced 2025-08-04 06:38:26 +00:00; 300 lines, 11 KiB, C++)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
#include <algorithm>
#include <iostream>

#include <include/args.h>
#include <include/paddlestructure.h>
#include <include/structure_layout.h>
#include <include/structure_table.h>

#include <auto_log/autolog.h>
|
|
|
|
namespace PaddleOCR {
|
|
|
|
// Private model holder (pimpl) for PaddleStructure. Each recognizer is only
// populated by the constructor when its corresponding flag is enabled, so
// either pointer may legitimately be empty.
struct PaddleStructure::STRUCTURE_PRIVATE {
  // Table-structure recognizer; created when FLAGS_table is set.
  std::unique_ptr<StructureTableRecognizer> table_model_{};
  // Layout recognizer; created when FLAGS_layout is set.
  std::unique_ptr<StructureLayoutRecognizer> layout_model_{};
};
|
|
|
|
// Builds the structure pipeline. The layout and table recognizers are
// constructed from command-line flags, and only when FLAGS_layout /
// FLAGS_table respectively are enabled; otherwise they stay empty.
PaddleStructure::PaddleStructure() noexcept : pri_(new STRUCTURE_PRIVATE) {
  // Short alias for the private model holder allocated above.
  STRUCTURE_PRIVATE &models = *this->pri_;

  // Layout analysis model.
  if (FLAGS_layout) {
    models.layout_model_.reset(new StructureLayoutRecognizer(
        FLAGS_layout_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_layout_dict_path,
        FLAGS_use_tensorrt, FLAGS_precision, FLAGS_layout_score_threshold,
        FLAGS_layout_nms_threshold));
  }

  // Table structure recognition model.
  if (FLAGS_table) {
    models.table_model_.reset(new StructureTableRecognizer(
        FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
        FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
        FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
        FLAGS_table_max_len, FLAGS_merge_no_span_structure));
  }
}
|
|
|
|
// Releases the private model holder allocated with `new` in the constructor.
// NOTE(review): pri_ is a raw owning pointer, so copying a PaddleStructure
// would double-delete it — confirm copy operations are deleted in the header.
PaddleStructure::~PaddleStructure() {
  delete pri_;
}
|
|
|
|
std::vector<StructurePredictResult>
|
|
PaddleStructure::structure(const cv::Mat &srcimg, bool layout, bool table,
|
|
bool ocr) noexcept {
|
|
cv::Mat img;
|
|
srcimg.copyTo(img);
|
|
|
|
std::vector<StructurePredictResult> structure_results;
|
|
|
|
if (layout) {
|
|
this->layout(img, structure_results);
|
|
} else {
|
|
StructurePredictResult res;
|
|
res.type = "table";
|
|
res.box.resize(4, 0.0);
|
|
res.box[2] = img.cols;
|
|
res.box[3] = img.rows;
|
|
structure_results.emplace_back(std::move(res));
|
|
}
|
|
cv::Mat roi_img;
|
|
for (size_t i = 0; i < structure_results.size(); ++i) {
|
|
// crop image
|
|
roi_img = std::move(Utility::crop_image(img, structure_results[i].box));
|
|
if (structure_results[i].type == "table" && table) {
|
|
this->table(roi_img, structure_results[i]);
|
|
} else if (ocr) {
|
|
structure_results[i].text_res =
|
|
std::move(this->ocr(roi_img, true, true, false));
|
|
}
|
|
}
|
|
|
|
return structure_results;
|
|
}
|
|
|
|
void PaddleStructure::layout(
|
|
const cv::Mat &img,
|
|
std::vector<StructurePredictResult> &structure_result) noexcept {
|
|
std::vector<double> layout_times;
|
|
this->pri_->layout_model_->Run(img, structure_result, layout_times);
|
|
|
|
this->time_info_layout[0] += layout_times[0];
|
|
this->time_info_layout[1] += layout_times[1];
|
|
this->time_info_layout[2] += layout_times[2];
|
|
}
|
|
|
|
// Recognizes the table in `img` and fills `structure_result` with:
//   html       rebuilt HTML with OCR text placed into cells,
//   cell_box   the predicted cell boxes,
//   html_score the table model's score.
// Table-model timings are added to time_info_table.
//
// If the table model was not created (FLAGS_table false at construction),
// a warning is printed and `structure_result` is left unchanged.
void PaddleStructure::table(const cv::Mat &img,
                            StructurePredictResult &structure_result) noexcept {
  // table_model_ is only populated when FLAGS_table is set (see the
  // constructor); calling Run through an empty unique_ptr would be UB.
  if (!this->pri_->table_model_) {
    std::cerr << "[WARNING] table model is not loaded (FLAGS_table is "
                 "false); skipping table recognition.\n";
    return;
  }

  // predict structure
  std::vector<std::vector<std::string>> structure_html_tags;
  std::vector<float> structure_scores(1, 0);
  std::vector<std::vector<std::vector<int>>> structure_boxes;
  std::vector<double> structure_times;
  std::vector<cv::Mat> img_list(1, img);

  this->pri_->table_model_->Run(img_list, structure_html_tags, structure_scores,
                                structure_boxes, structure_times);

  // Accumulate the (up to three) stage timings, bounds-safely.
  for (size_t k = 0; k < structure_times.size() && k < 3; ++k) {
    this->time_info_table[k] += structure_times[k];
  }

  // Detected text boxes are grown by this many pixels (clamped to the image)
  // before cropping for recognition.
  const int expand_pixel = 3;

  for (size_t i = 0; i < img_list.size(); ++i) {
    // Guard against the model returning fewer results than input images.
    if (i >= structure_html_tags.size() || i >= structure_boxes.size() ||
        i >= structure_scores.size()) {
      break;
    }

    // Fresh OCR results per image. det() presumably appends and rec()
    // presumably writes back by position (TODO confirm in PPOCR), so reusing
    // one vector across images would mix their results.
    std::vector<OCRPredictResult> ocr_result;

    // det
    this->det(img_list[i], ocr_result);

    // crop image
    std::vector<cv::Mat> rec_img_list;
    rec_img_list.reserve(ocr_result.size());
    for (size_t j = 0; j < ocr_result.size(); ++j) {
      std::vector<int> ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[j].box);
      ocr_box[0] = std::max(0, ocr_box[0] - expand_pixel);
      ocr_box[1] = std::max(0, ocr_box[1] - expand_pixel);
      ocr_box[2] = std::min(img_list[i].cols, ocr_box[2] + expand_pixel);
      ocr_box[3] = std::min(img_list[i].rows, ocr_box[3] + expand_pixel);
      rec_img_list.emplace_back(Utility::crop_image(img_list[i], ocr_box));
    }

    // rec
    this->rec(rec_img_list, ocr_result);

    // rebuild table
    structure_result.html = this->rebuild_table(
        structure_html_tags[i], structure_boxes[i], ocr_result);
    structure_result.cell_box = std::move(structure_boxes[i]);
    structure_result.html_score = structure_scores[i];
  }
}
|
|
|
|
std::string PaddleStructure::rebuild_table(
|
|
const std::vector<std::string> &structure_html_tags,
|
|
const std::vector<std::vector<int>> &structure_boxes,
|
|
std::vector<OCRPredictResult> &ocr_result) noexcept {
|
|
// match text in same cell
|
|
std::vector<std::vector<std::string>> matched(structure_boxes.size(),
|
|
std::vector<std::string>());
|
|
|
|
std::vector<int> ocr_box;
|
|
std::vector<int> structure_box;
|
|
for (size_t i = 0; i < ocr_result.size(); ++i) {
|
|
ocr_box = std::move(Utility::xyxyxyxy2xyxy(ocr_result[i].box));
|
|
ocr_box[0] -= 1;
|
|
ocr_box[1] -= 1;
|
|
ocr_box[2] += 1;
|
|
ocr_box[3] += 1;
|
|
std::vector<std::vector<float>> dis_list(structure_boxes.size(),
|
|
std::vector<float>(3, 100000.0));
|
|
for (size_t j = 0; j < structure_boxes.size(); ++j) {
|
|
if (structure_boxes[j].size() == 8) {
|
|
structure_box = std::move(Utility::xyxyxyxy2xyxy(structure_boxes[j]));
|
|
} else {
|
|
structure_box = structure_boxes[j];
|
|
}
|
|
dis_list[j][0] = this->dis(ocr_box, structure_box);
|
|
dis_list[j][1] = 1 - Utility::iou(ocr_box, structure_box);
|
|
dis_list[j][2] = j;
|
|
}
|
|
// find min dis idx
|
|
std::sort(dis_list.begin(), dis_list.end(),
|
|
PaddleStructure::comparison_dis);
|
|
matched[dis_list[0][2]].emplace_back(ocr_result[i].text);
|
|
}
|
|
|
|
// get pred html
|
|
std::string html_str = "";
|
|
int td_tag_idx = 0;
|
|
for (size_t i = 0; i < structure_html_tags.size(); ++i) {
|
|
if (structure_html_tags[i].find("</td>") != std::string::npos) {
|
|
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
|
|
html_str += "<td>";
|
|
}
|
|
if (matched[td_tag_idx].size() > 0) {
|
|
bool b_with = false;
|
|
if (matched[td_tag_idx][0].find("<b>") != std::string::npos &&
|
|
matched[td_tag_idx].size() > 1) {
|
|
b_with = true;
|
|
html_str += "<b>";
|
|
}
|
|
for (size_t j = 0; j < matched[td_tag_idx].size(); ++j) {
|
|
std::string content = matched[td_tag_idx][j];
|
|
if (matched[td_tag_idx].size() > 1) {
|
|
// remove blank, <b> and </b>
|
|
if (content.length() > 0 && content.at(0) == ' ') {
|
|
content = content.substr(0);
|
|
}
|
|
if (content.length() > 2 && content.substr(0, 3) == "<b>") {
|
|
content = content.substr(3);
|
|
}
|
|
if (content.length() > 4 &&
|
|
content.substr(content.length() - 4) == "</b>") {
|
|
content = content.substr(0, content.length() - 4);
|
|
}
|
|
if (content.empty()) {
|
|
continue;
|
|
}
|
|
// add blank
|
|
if (j != matched[td_tag_idx].size() - 1 &&
|
|
content.at(content.length() - 1) != ' ') {
|
|
content += ' ';
|
|
}
|
|
}
|
|
html_str += content;
|
|
}
|
|
if (b_with) {
|
|
html_str += "</b>";
|
|
}
|
|
}
|
|
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
|
|
html_str += "</td>";
|
|
} else {
|
|
html_str += structure_html_tags[i];
|
|
}
|
|
td_tag_idx += 1;
|
|
} else {
|
|
html_str += structure_html_tags[i];
|
|
}
|
|
}
|
|
return html_str;
|
|
}
|
|
|
|
// Distance between two axis-aligned boxes given as [x1, y1, x2, y2].
// Returns the sum of the Manhattan distances between their top-left corners
// and between their bottom-right corners, plus the smaller of those two
// corner distances. The extra term favours boxes that nearly coincide at one
// corner.
float PaddleStructure::dis(const std::vector<int> &box1,
                           const std::vector<int> &box2) noexcept {
  const int top_left = abs(box2[0] - box1[0]) + abs(box2[1] - box1[1]);
  const int bottom_right = abs(box2[2] - box1[2]) + abs(box2[3] - box1[3]);

  const float total = static_cast<float>(top_left + bottom_right);
  return total + std::min(static_cast<float>(top_left),
                          static_cast<float>(bottom_right));
}
|
|
|
|
void PaddleStructure::reset_timer() noexcept {
|
|
this->time_info_det = {0, 0, 0};
|
|
this->time_info_rec = {0, 0, 0};
|
|
this->time_info_cls = {0, 0, 0};
|
|
this->time_info_table = {0, 0, 0};
|
|
this->time_info_layout = {0, 0, 0};
|
|
}
|
|
|
|
void PaddleStructure::benchmark_log(int img_num) noexcept {
|
|
if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] >
|
|
0) {
|
|
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
|
|
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
|
|
FLAGS_precision, this->time_info_det, img_num);
|
|
autolog_det.report();
|
|
}
|
|
if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] >
|
|
0) {
|
|
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
|
|
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
|
|
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
|
|
this->time_info_rec, img_num);
|
|
autolog_rec.report();
|
|
}
|
|
if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] >
|
|
0) {
|
|
AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
|
|
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
|
|
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
|
|
this->time_info_cls, img_num);
|
|
autolog_cls.report();
|
|
}
|
|
if (this->time_info_table[0] + this->time_info_table[1] +
|
|
this->time_info_table[2] >
|
|
0) {
|
|
AutoLogger autolog_table("table", FLAGS_use_gpu, FLAGS_use_tensorrt,
|
|
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
|
|
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
|
|
this->time_info_table, img_num);
|
|
autolog_table.report();
|
|
}
|
|
if (this->time_info_layout[0] + this->time_info_layout[1] +
|
|
this->time_info_layout[2] >
|
|
0) {
|
|
AutoLogger autolog_layout("layout", FLAGS_use_gpu, FLAGS_use_tensorrt,
|
|
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
|
|
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
|
|
this->time_info_layout, img_num);
|
|
autolog_layout.report();
|
|
}
|
|
}
|
|
|
|
} // namespace PaddleOCR
|