co63oc 2025-03-04 11:04:41 +08:00 committed by GitHub
parent f35dfef921
commit de12ece0aa
45 changed files with 66 additions and 66 deletions

View File

@@ -67,9 +67,9 @@ def get_dataloader(module_config, distributed=False):
     config = copy.deepcopy(module_config)
     dataset_args = config["dataset"]["args"]
     if "transforms" in dataset_args:
-        img_transfroms = get_transforms(dataset_args.pop("transforms"))
+        img_transforms = get_transforms(dataset_args.pop("transforms"))
     else:
-        img_transfroms = None
+        img_transforms = None
     # create the dataset
     dataset_name = config["dataset"]["type"]
     data_path = dataset_args.pop("data_path")
@@ -91,7 +91,7 @@ def get_dataloader(module_config, distributed=False):
     _dataset = get_dataset(
         data_path=data_path,
         module_name=dataset_name,
-        transform=img_transfroms,
+        transform=img_transforms,
         dataset_args=dataset_args,
     )
     sampler = None

View File

@@ -1570,7 +1570,7 @@ bool Clipper::ExecuteInternal() {
 void Clipper::SetWindingCount(TEdge &edge) {
   TEdge *e = edge.PrevInAEL;
-  // find the edge of the same polytype that immediately preceeds 'edge' in AEL
+  // find the edge of the same polytype that immediately precedes 'edge' in AEL
   while (e && ((e->PolyTyp != edge.PolyTyp) || (e->WindDelta == 0)))
     e = e->PrevInAEL;
   if (!e) {

View File

@@ -36,7 +36,7 @@
 // improve performance but coordinate values are limited to the range +/- 46340
 //#define use_int32
-// use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance.
+// use_xyz: adds a Z member to IntPoint. Adds a minor cost to performance.
 //#define use_xyz
 // use_lines: Enables line clipping. Adds a very minor cost to performance.

View File

@@ -110,7 +110,7 @@ private:
   ClsPredictResult infer_cls(const cv::Mat &origin, float thresh = 0.9);
   /**
-   * Postprocess or sencod model to extract text
+   * Postprocess or second model to extract text
    * @param res
    * @return
    */

View File

@@ -53,7 +53,7 @@ public class OCRPredictorNative {
     }
-    public void destory() {
+    public void destroy() {
         if (nativePointer != 0) {
             release(nativePointer);
             nativePointer = 0;

View File

@@ -101,7 +101,7 @@ public class Predictor {
     public void releaseModel() {
         if (paddlePredictor != null) {
-            paddlePredictor.destory();
+            paddlePredictor.destroy();
             paddlePredictor = null;
         }
         isLoaded = false;

View File

@@ -1573,7 +1573,7 @@ bool Clipper::ExecuteInternal() noexcept {
 void Clipper::SetWindingCount(TEdge &edge) noexcept {
   TEdge *e = edge.PrevInAEL;
-  // find the edge of the same polytype that immediately preceeds 'edge' in AEL
+  // find the edge of the same polytype that immediately precedes 'edge' in AEL
   while (e && ((e->PolyTyp != edge.PolyTyp) || (e->WindDelta == 0)))
     e = e->PrevInAEL;
   if (!e) {

View File

@@ -390,7 +390,7 @@ void TablePostProcessor::Run(
     const std::vector<int> &width_list,
     const std::vector<int> &height_list) noexcept {
   for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; ++batch_idx) {
-    // image tags and boxs
+    // image tags and boxes
     std::vector<std::string> rec_html_tags;
     std::vector<std::vector<int>> rec_boxes;

View File

@@ -18,7 +18,7 @@
 </p>
 > **Note:**
->> If you see an NDK configuration error while importing, building, or running the project, open ` File > Project Structure > SDK Location` and change `Andriod SDK location` to the path of the SDK configured on your machine.
+>> If you see an NDK configuration error while importing, building, or running the project, open ` File > Project Structure > SDK Location` and change `Android SDK location` to the path of the SDK configured on your machine.
 4. Click the Run button to automatically build the APP and install it on the phone. (This process automatically downloads the pre-built FastDeploy Android library and the model files, so a network connection is required.)
 After success, the result is as follows. Figure 1: the APP installed on the phone; Figure 2: the opened APP, which automatically recognizes and marks objects in the image; Figure 3: the APP settings, where tapping the settings icon in the upper right corner lets you try different options.

View File

@@ -33,7 +33,7 @@ The introduction and tutorial of Paddle Serving service deployment framework ref
 - [Environmental preparation](#environmental-preparation)
 - [Model conversion](#model-conversion)
 - [Paddle Serving pipeline deployment](#paddle-serving-pipeline-deployment)
-- [Paddle Serving C++ deployment](#C++)
+- [C++ Serving](#c-serving)
 - [WINDOWS Users](#windows-users)
 - [FAQ](#faq)
@@ -247,7 +247,7 @@ The C++ service deployment is the same as python in the environment setup and da
 ## WINDOWS Users
-Windows does not support Pipeline Serving, if we want to lauch paddle serving on Windows, we should use Web Service, for more infomation please refer to [Paddle Serving for Windows Users](https://github.com/PaddlePaddle/Serving/blob/develop/doc/Windows_Tutorial_EN.md)
+Windows does not support Pipeline Serving, if we want to launch paddle serving on Windows, we should use Web Service, for more information please refer to [Paddle Serving for Windows Users](https://github.com/PaddlePaddle/Serving/blob/develop/doc/Windows_Tutorial_EN.md)
 **WINDOWS user can only use version 0.5.0 CPU Mode**

View File

@@ -11,7 +11,7 @@ Paper:
 > Darwin Bautista, Rowel Atienza
 > ECCV, 2021
-Using real datasets (real) and synthetic datsets (synth) for training respectively, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets.
+Using real datasets (real) and synthetic datasets (synth) for training respectively, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets.
 - The real datasets include COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR and OpenVINO datasets.
 - The synthesis datasets include MJSynth and SynthText datasets.

View File

@@ -43,5 +43,5 @@ Here are the common datasets key information extraction, which are being updated
 **Note** Boxes with category `Ignore` or `Others` are not visualized here.
 - **Download address**
-  - Offical dataset: [link](https://download.openmmlab.com/mmocr/data/wildreceipt.tar)
+  - Official dataset: [link](https://download.openmmlab.com/mmocr/data/wildreceipt.tar)
   - Dataset converted for PaddleOCR training process: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar)

View File

@@ -5,7 +5,7 @@ comments: true
 ## Layout Analysis Dataset
-Here are the common datasets of layout anlysis, which are being updated continuously. Welcome to contribute datasets.
+Here are the common datasets of layout analysis, which are being updated continuously. Welcome to contribute datasets.
 Most of the layout analysis datasets are object detection datasets. In addition to open source datasets, you can also label or synthesize datasets using tools such as [labelme](https://github.com/wkentaro/labelme) and so on.
@@ -33,7 +33,7 @@ Most of the layout analysis datasets are object detection datasets. In addition
 - **Download address**: <https://github.com/buptlihang/CDLA>
 - **Note**: When you train detection model on CDLA dataset using [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection/tree/develop), you need to remove the label `__ignore__` and `_background_`.
-### 3、TableBank dataet
+### 3、TableBank dataset
 - **Data source**: <https://doc-analysis.github.io/tablebank-page/index.html>
 - **Data introduction**: TableBank dataset contains 2 types of document: Latex (187199 training images, 7265 validation images and 5719 testing images) and Word (73383 training images, 2735 validation images and 2281 testing images). Some images and their annotations as shown below.

View File

@@ -13,7 +13,7 @@ Paddle provides a variety of deployment schemes to meet the deployment requireme
 ![img](./images/deployment_en.jpg)
-PP-OCR has supported muti deployment schemes. Click the link to get the specific tutorial.
+PP-OCR has supported multi deployment schemes. Click the link to get the specific tutorial.
 - [Python Inference](./python_infer.en.md)
 - [C++ Inference](./cpp_infer.en.md)

View File

@@ -19,7 +19,7 @@ Run OCR demo in browser refer to [tutorial](https://github.com/PaddlePaddle/Fast
 ## Mini Program Demo
-The Mini Program demo running tutorial eference
+The Mini Program demo running tutorial reference
 Run OCR demo in wechat miniprogram refer to [tutorial](https://github.com/PaddlePaddle/FastDeploy/tree/develop/examples/application/js/mini_program).
 |demo|directory|

View File

@@ -250,7 +250,7 @@ The C++ service deployment is the same as python in the environment setup and da
 ### WINDOWS Users
-Windows does not support Pipeline Serving, if we want to lauch paddle serving on Windows, we should use Web Service, for more infomation please refer to [Paddle Serving for Windows Users](https://github.com/PaddlePaddle/Serving/blob/develop/doc/Windows_Tutorial_EN.md)
+Windows does not support Pipeline Serving, if we want to launch paddle serving on Windows, we should use Web Service, for more information please refer to [Paddle Serving for Windows Users](https://github.com/PaddlePaddle/Serving/blob/develop/doc/Windows_Tutorial_EN.md)
 **WINDOWS user can only use version 0.5.0 CPU Mode**

View File

@@ -480,7 +480,7 @@ for idx in range(len(result)):
 | enable_mkldnn | Whether to enable mkldnn | FALSE |
 | use_zero_copy_run | Whether to forward by zero_copy_run | FALSE |
 | lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch |
-| det | Enable detction when `ppocr.ocr` func exec | TRUE |
+| det | Enable detection when `ppocr.ocr` func exec | TRUE |
 | rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
 | cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
 | show_log | Whether to print log | FALSE |

View File

@@ -316,7 +316,7 @@ Metric:
   base_metric_name: RecMetric # The base class of indicator calculation. For the output of the model, the indicator will be calculated based on this class
   main_indicator: acc # The name of the indicator
   key: "Student" # Select the main_indicator of this subnet as the criterion for saving the best model
-  ignore_space: False # whether to ignore space during evaulation
+  ignore_space: False # whether to ignore space during evaluation
 ```
 Taking the above configuration as an example, the accuracy metric of the `Student` subnet will be used as the judgment metric for saving the best model.

View File

@@ -130,7 +130,7 @@ After adding the four-part modules of the network, you only need to configure th
     args1: args1
 ```
-**NOTE**: More details about replace Backbone and other mudule can be found in [doc](../../algorithm/add_new_algorithm.en.md).
+**NOTE**: More details about replace Backbone and other module can be found in [doc](../../algorithm/add_new_algorithm.en.md).
 ### 2.4 Mixed Precision Training

View File

@@ -93,7 +93,7 @@ The final dataset shall have the following file structure.
 ### 1.3. Download data
-If you do not have local dataset, you can donwload the source files of [XFUND](https://github.com/doc-analysis/XFUND) or [FUNSD](https://guillaumejaume.github.io/FUNSD) and use the scripts of [XFUND](../../ppstructure/kie/tools/trans_xfun_data.py) or [FUNSD](../../ppstructure/kie/tools/trans_funsd_label.py) for tranform them into PaddleOCR format. Then you can use the public dataset to quick experience KIE.
+If you do not have local dataset, you can download the source files of [XFUND](https://github.com/doc-analysis/XFUND) or [FUNSD](https://guillaumejaume.github.io/FUNSD) and use the scripts of [XFUND](../../ppstructure/kie/tools/trans_xfun_data.py) or [FUNSD](../../ppstructure/kie/tools/trans_funsd_label.py) for transform them into PaddleOCR format. Then you can use the public dataset to quick experience KIE.
 For more information about public KIE datasets, please refer to [KIE dataset tutorial](../../datasets/kie_datasets.en.md).
@@ -191,7 +191,7 @@ Architecture:
   name: LayoutXLMForSer
   pretrained: True
   mode: vi
-  # Assuming that n categroies are included in the dictionary file (other is included), the the num_classes is set as 2n-1
+  # Assuming that n categories are included in the dictionary file (other is included), the the num_classes is set as 2n-1
   num_classes: &num_classes 7
 PostProcess:
@@ -239,7 +239,7 @@ python3 tools/train.py -c configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml
 **Note:**
-- Priority of `Architecture.Backbone.checkpoints` is higher than`Architecture.Backbone.pretrained`. You need to set `Architecture.Backbone.checkpoints` for model finetuning, resume and evalution. If you want to train with the NLP pretrained model, you need to set `Architecture.Backbone.pretrained` as `True` and set `Architecture.Backbone.checkpoints` as null (`null`).
+- Priority of `Architecture.Backbone.checkpoints` is higher than`Architecture.Backbone.pretrained`. You need to set `Architecture.Backbone.checkpoints` for model finetuning, resume and evaluation. If you want to train with the NLP pretrained model, you need to set `Architecture.Backbone.pretrained` as `True` and set `Architecture.Backbone.checkpoints` as null (`null`).
 - PaddleNLP pretrained models are used here for LayoutXLM series models, the model loading and saving logic is same as those in PaddleNLP. Therefore we do not need to set `Global.pretrained_model` or `Global.checkpoints` here.
 - If you use knowledge distillation to train the LayoutXLM series models, resuming training is not supported now.
@@ -280,7 +280,7 @@ Running on a DCU device requires setting the environment variable `export HIP_VI
 ### 3.1. Evaluation
-The trained model will be saved in `Global.save_model_dir`. When evaluation, you need to set `Architecture.Backbone.checkpoints` as your model directroy. The evaluation dataset can be set by modifying the `Eval.dataset.label_file_list` field in the `configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml` file.
+The trained model will be saved in `Global.save_model_dir`. When evaluation, you need to set `Architecture.Backbone.checkpoints` as your model directory. The evaluation dataset can be set by modifying the `Eval.dataset.label_file_list` field in the `configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml` file.
 ```bash linenums="1"
 # GPU evaluation, Global.checkpoints is the weight to be tested

View File

@@ -325,7 +325,7 @@ After adding the four-part modules of the network, you only need to configure th
     args1: args1
 ```
-**NOTE**: More details about replace Backbone and other mudule can be found in [doc](../../algorithm/add_new_algorithm.en.md).
+**NOTE**: More details about replace Backbone and other module can be found in [doc](../../algorithm/add_new_algorithm.en.md).
 ### 2.4. Mixed Precision Training

View File

@@ -47,7 +47,7 @@ In the non end-to-end KIE method, KIE needs at least **2 steps**. Firstly, the O
 ##### (1) Data
-Most of the models provided in PaddleOCR are general models. In the process of text detection, the detection of adjacent text lines is generally based on the distance of the position. As shown in the figure above, when using PP-OCRv3 general English detection model for text detection, it is easy to detect the two fields representing different propoerties as one. Therefore, it is suggested to finetune a detection model according to your scenario firstly during the KIE task.
+Most of the models provided in PaddleOCR are general models. In the process of text detection, the detection of adjacent text lines is generally based on the distance of the position. As shown in the figure above, when using PP-OCRv3 general English detection model for text detection, it is easy to detect the two fields representing different properties as one. Therefore, it is suggested to finetune a detection model according to your scenario firstly during the KIE task.
 During data annotation, the different key information needs to be separated. Otherwise, it will increase the difficulty of subsequent KIE tasks.

View File

@@ -60,7 +60,7 @@ cd PaddleOCR/ppstructure
 ## download model
 cd inference
-## Download the detection model of the ultra-lightweight Chinesse PP-OCRv3 model and unzip it
+## Download the detection model of the ultra-lightweight Chinese PP-OCRv3 model and unzip it
 wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_det_infer.tar
 ## Download the recognition model of the ultra-lightweight Chinese PP-OCRv3 model and unzip it
 wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar

View File

@@ -81,7 +81,7 @@ pip3 install pdf2docx-0.0.0-py3-none-any.whl
 ## 3. Quick Start using standard PDF parse
-`use_pdf2docx_api` use PDF parse for layout recovery, The whl package is also provided for quick use, follow the above code, for more infomation please refer to [quickstart](../quick_start.en.md) for details.
+`use_pdf2docx_api` use PDF parse for layout recovery, The whl package is also provided for quick use, follow the above code, for more information please refer to [quickstart](../quick_start.en.md) for details.
 ```bash linenums="1"
 # install paddleocr
@@ -110,7 +110,7 @@ Through layout analysis, we divided the image/PDF documents into regions, locate
 We can restore the test picture through the layout information, OCR detection and recognition structure, table information, and saved pictures.
-The whl package is also provided for quick use, follow the above code, for more infomation please refer to [quickstart](../quick_start.en.md) for details.
+The whl package is also provided for quick use, follow the above code, for more information please refer to [quickstart](../quick_start.en.md) for details.
 ```bash linenums="1"
 paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true --lang='en'
@@ -166,8 +166,8 @@ Field
 - det_model_dir: OCR detection model path
 - rec_model_dir: OCR recognition model path
 - rec_char_dict_path: OCR recognition dict path. If the Chinese model is used, change to "../ppocr/utils/ppocr_keys_v1.txt". And if you trained the model on your own dataset, change to the trained dictionary
-- table_model_dir: tabel recognition model path
-- table_char_dict_path: tabel recognition dict path. If the Chinese model is used, no need to change
+- table_model_dir: table recognition model path
+- table_char_dict_path: table recognition dict path. If the Chinese model is used, no need to change
 - layout_model_dir: layout analysis model path
 - layout_dict_path: layout analysis dict path. If the Chinese model is used, change to "../ppocr/utils/dict/layout_dict/layout_cdla_dict.txt"
 - recovery: whether to enable layout of recovery, default False

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-This code is refered from:
+This code is referred from:
 https://github.com/songdejia/EAST/blob/master/data_utils.py
 """
 import math
@@ -241,7 +241,7 @@ class EASTProcessTrain(object):
         score_map = np.zeros((h, w), dtype=np.uint8)
         # (x1, y1, ..., x4, y4, short_edge_norm)
         geo_map = np.zeros((h, w, 9), dtype=np.float32)
-        # mask used during traning, to ignore some hard areas
+        # mask used during training, to ignore some hard areas
         training_mask = np.ones((h, w), dtype=np.uint8)
         for poly_idx, poly_tag in enumerate(zip(polys, tags)):
             poly = poly_tag[0]

View File

@@ -60,7 +60,7 @@ class DecodeImage(object):
 class NormalizeImage(object):
-    """normalize image such as substract mean, divide std"""
+    """normalize image such as subtract mean, divide std"""
     def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
         if isinstance(scale, str):

View File

@@ -763,7 +763,7 @@ class PGProcessTrain(object):
     def prepare_text_label(self, label_str, Lexicon_Table):
         """
-        Prepare text lablel by given Lexicon_Table.
+        Prepare text label by given Lexicon_Table.
         """
         if len(Lexicon_Table) == 36:
             return label_str.lower()

View File

@@ -413,7 +413,7 @@ class SARRecResizeImg(object):
 class PRENResizeImg(object):
     def __init__(self, image_shape, **kwargs):
         """
-        Accroding to original paper's realization, it's a hard resize method here.
+        According to original paper's realization, it's a hard resize method here.
         So maybe you should optimize it to fit for your task better.
         """
         self.dst_h, self.dst_w = image_shape

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-This part code is refered from:
+This part code is referred from:
 https://github.com/songdejia/EAST/blob/master/data_utils.py
 """
 import math

View File

@@ -206,7 +206,7 @@ class MultiScaleDataSet(SimpleDataSet):
         return data
     def __getitem__(self, properties):
-        # properites is a tuple, contains (width, height, index)
+        # properties is a tuple, contains (width, height, index)
        img_height = properties[1]
        idx = properties[2]
        if self.ds_width and properties[3] is not None:

View File

@@ -30,7 +30,7 @@ class DBLoss(nn.Layer):
     """
     Differentiable Binarization (DB) Loss Function
     args:
-        param (dict): the super paramter for DB Loss
+        param (dict): the super parameter for DB Loss
     """
     def __init__(

View File

@@ -1016,7 +1016,7 @@ class CTCDKDLoss(nn.Layer):
         targets = targets.astype("int32")
         res = F.one_hot(targets, num_classes=11465)
         mask = paddle.clip(paddle.sum(res, axis=1), 0, 1)
-        mask[:, 0] = 0  # ingore ctc blank label
+        mask[:, 0] = 0  # ignore ctc blank label
         return mask
     def forward(self, logits_student, logits_teacher, targets, mask=None):
@@ -1128,7 +1128,7 @@ class KLCTCLogits(nn.Layer):
             tea_out *= blank_mask
             return self.forward_meanlog(stu_out, tea_out)
         elif self.mode == "ctcdkd":
-            # ingore ctc blank logits
+            # ignore ctc blank logits
             blank_mask = paddle.ones_like(stu_out)
             blank_mask.stop_gradient = True
             blank_mask[:, :, 0] = -1

View File

@@ -34,7 +34,7 @@ class CTMetric(object):
     def __call__(self, preds, batch, **kwargs):
         # NOTE: only support bs=1 now, as the label length of different sample is Unequal
-        assert len(preds) == 1, "CentripetalText test now only suuport batch_size=1."
+        assert len(preds) == 1, "CentripetalText test now only support batch_size=1."
         label = batch[2]
         text = batch[3]
         pred = preds[0]["points"]

View File

@@ -34,10 +34,10 @@ class BaseModel(nn.Layer):
         super(BaseModel, self).__init__()
         in_channels = config.get("in_channels", 3)
         model_type = config["model_type"]
-        # build transfrom,
-        # for rec, transfrom can be TPS,None
-        # for det and cls, transfrom shoule to be None,
-        # if you make model differently, you can use transfrom in det and cls
+        # build transform,
+        # for rec, transform can be TPS,None
+        # for det and cls, transform shoule to be None,
+        # if you make model differently, you can use transform in det and cls
         if "Transform" not in config or config["Transform"] is None:
             self.use_transform = False
         else:

View File

@@ -428,7 +428,7 @@ class LKPAN(nn.Layer):
 class ASFBlock(nn.Layer):
     """
-    This code is refered from:
+    This code is referred from:
     https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
     """

View File

@@ -128,7 +128,7 @@ class GA_SPIN_Transformer(nn.Layer):
             default_type (int): the K chromatic space,
                 set it to 3/5/6 depend on the complexity of transformation intensities
             loc_lr (float): learning rate of location network
-            stn (bool): whther to use stn.
+            stn (bool): whether to use stn.
         """
         super(GA_SPIN_Transformer, self).__init__()

View File

@@ -49,7 +49,7 @@ class OneCycleDecay(LRScheduler):
     """
     One Cycle learning rate decay
     A learning rate which can be referred in https://arxiv.org/abs/1708.07120
-    Code refered in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
+    Code referred in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
     """
     def __init__(

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-This code is refered from:
+This code is referred from:
 https://github.com/shengtao96/CentripetalText/blob/main/test.py
 """

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-This code is refered from:
+This code is referred from:
 https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
 """
 from __future__ import absolute_import

View File

@@ -1,6 +1,6 @@
 """
 Locality aware nms.
-This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
+This code is referred from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
 """
 import numpy as np
@@ -71,7 +71,7 @@ def standard_nms(S, thres):
 def standard_nms_inds(S, thres):
     """
-    Standard nms, retun inds.
+    Standard nms, return inds.
     """
     order = np.argsort(S[:, 8])[::-1]
     keep = []
@@ -158,8 +158,8 @@ def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
         else:
             weight = 1
         boxes[pos, 8] = weight * boxes[pos, 8]
-        # if box score falls below thresold, discard the box by
-        # swaping last box update N
+        # if box score falls below threshold, discard the box by
+        # swapping last box update N
         if boxes[pos, 8] < threshold:
             boxes[pos, :] = boxes[N - 1, :]
             inds[pos] = inds[N - 1]

View File

@@ -80,8 +80,8 @@ class BaseRecLabelDecode(object):
             word_list: list of the grouped words
             word_col_list: list of decoding positions corresponding to each character in the grouped word
             state_list: list of marker to identify the type of grouping words, including two types of grouping words:
-                - 'cn': continous chinese characters (e.g., 你好啊)
-                - 'en&num': continous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
+                - 'cn': continuous chinese characters (e.g., 你好啊)
+                - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
                 The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
         """
         state = None
@@ -105,7 +105,7 @@ class BaseRecLabelDecode(object):
                 and state == "en&num"
                 and c_i + 1 < len(text)
                 and bool(re.search("[0-9]", text[c_i + 1]))
-            ):  # grouping floting number
+            ):  # grouping floating number
                 c_state = "en&num"
             if (
                 char == "-" and state == "en&num"

View File

@@ -1,5 +1,5 @@
 ## Dictionary and Corpus
-Dictionary files (usually character level vocabulary) are included here for easier configuration. Corpus contributed by OSS contirbutors are listed here, please respect copyrights when using them at your own risk.
+Dictionary files (usually character level vocabulary) are included here for easier configuration. Corpus contributed by OSS contributors are listed here, please respect copyrights when using them at your own risk.
 - Burmese corpus: https://github.com/1chimaruGin/BurmeseCorpus

View File

@@ -386,7 +386,7 @@ def generate_pivot_list_curved(
             pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
             all_pos_yxs.append(pos_list_sorted)
-    # use decoder to filter backgroud points.
+    # use decoder to filter background points.
     p_char_maps = p_char_maps.transpose([1, 2, 0])
     decode_res = ctc_decoder_for_image(
         all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True
@@ -464,7 +464,7 @@ def generate_pivot_list_horizontal(
             pos_list_sorted, _ = sort_with_direction(pos_list_final, f_direction)
             all_pos_yxs.append(pos_list_sorted)
-    # use decoder to filter backgroud points.
+    # use decoder to filter background points.
     p_char_maps = p_char_maps.transpose([1, 2, 0])
     decode_res = ctc_decoder_for_image(
         all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True

View File

@@ -83,7 +83,7 @@ def _download(url, save_path):
             "{}!".format(url, req.status_code)
         )
-    # For protecting download interupted, download to
+    # For protecting download interrupted, download to
     # tmp_file firstly, move tmp_file to save_path
    # after download finished
    tmp_file = save_path + ".tmp"

View File

@@ -180,7 +180,7 @@ def set_seed(seed=1024):
 def check_install(module_name, install_name):
     spec = importlib.util.find_spec(module_name)
     if spec is None:
-        print(f"Warnning! The {module_name} module is NOT installed")
+        print(f"Warning! The {module_name} module is NOT installed")
         print(
             f"Try install {module_name} module automatically. You can also try to install manually by pip install {install_name}."
         )