This commit is contained in:
co63oc 2025-03-05 19:09:08 +08:00 committed by GitHub
parent 9e7a1f4cc1
commit e061055808
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 36 additions and 36 deletions

View File

@@ -31,7 +31,7 @@ PaddleOCR is being oversight by a [PMC](https://github.com/PaddlePaddle/PaddleOC
 - **🔥 2024.10.18 release PaddleOCR v2.9, including**:
   - PaddleX, an All-in-One development tool based on PaddleOCR's advanced technology, supports low-code full-process development capabilities in the OCR field:
-    - 🎨 [**Rich Model One-Click Call**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/quick_start.html): Integrates **17 models** related to text image intelligent analysis, general OCR, general layout parsing, table recognition, formula recognition, and seal recognition into 6 pipelines, which can be quickly experienced through a simple **Python API one-click call**. In addition, the same set of APIs also supports a total of **200+ models** in image classification, object detection, image segmentation, and time series forcasting, forming 20+ single-function modules, making it convenient for developers to use **model combinations**.
+    - 🎨 [**Rich Model One-Click Call**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/quick_start.html): Integrates **17 models** related to text image intelligent analysis, general OCR, general layout parsing, table recognition, formula recognition, and seal recognition into 6 pipelines, which can be quickly experienced through a simple **Python API one-click call**. In addition, the same set of APIs also supports a total of **200+ models** in image classification, object detection, image segmentation, and time series forecasting, forming 20+ single-function modules, making it convenient for developers to use **model combinations**.
     - 🚀 [**High Efficiency and Low barrier of entry**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/overview.html): Provides two methods based on **unified commands** and **GUI** to achieve simple and efficient use, combination, and customization of models. Supports multiple deployment methods such as **high-performance inference, service-oriented deployment, and edge deployment**. Additionally, for various mainstream hardware such as **NVIDIA GPU, Kunlunxin XPU, Ascend NPU, Cambricon MLU, and Haiguang DCU**, models can be developed with **seamless switching**.
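The "Python API one-click call" described in the bullets above follows the PaddleX quick-start pattern. The snippet below is only an illustrative sketch, not part of this commit; the pipeline name "OCR" and the image path are assumptions.

```python
from paddlex import create_pipeline  # PaddleX >= 3.0-beta

# Build the general OCR pipeline and run it on one image (path is a placeholder).
pipeline = create_pipeline(pipeline="OCR")
for res in pipeline.predict("sample.png"):
    res.print()  # print detected boxes and recognized text
```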

View File

@@ -24,12 +24,12 @@ def shrink_polygon_pyclipper(polygon, shrink_ratio):
     subject = [tuple(l) for l in polygon]
     padding = pyclipper.PyclipperOffset()
     padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-    shrinked = padding.Execute(-distance)
-    if shrinked == []:
-        shrinked = np.array(shrinked)
+    shrunk = padding.Execute(-distance)
+    if shrunk == []:
+        shrunk = np.array(shrunk)
     else:
-        shrinked = np.array(shrinked[0]).reshape(-1, 2)
-    return shrinked
+        shrunk = np.array(shrunk[0]).reshape(-1, 2)
+    return shrunk
 class MakeShrinkMap:
@@ -69,12 +69,12 @@ class MakeShrinkMap:
                 cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
                 ignore_tags[i] = True
             else:
-                shrinked = self.shrink_func(polygon, self.shrink_ratio)
-                if shrinked.size == 0:
+                shrunk = self.shrink_func(polygon, self.shrink_ratio)
+                if shrunk.size == 0:
                     cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
                     ignore_tags[i] = True
                     continue
-                cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
+                cv2.fillPoly(gt, [shrunk.astype(np.int32)], 1)
         data["shrink_map"] = gt
         data["shrink_mask"] = mask

View File

@@ -15,10 +15,10 @@
 #include <gflags/gflags.h>
 // common args
-DEFINE_bool(use_gpu, false, "Infering with GPU or CPU.");
+DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU.");
 DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
 DEFINE_int32(gpu_id, 0, "Device id of GPU to execute.");
-DEFINE_int32(gpu_mem, 4000, "GPU id when infering with GPU.");
+DEFINE_int32(gpu_mem, 4000, "GPU id when inferring with GPU.");
 DEFINE_int32(cpu_threads, 10, "Num of threads with CPU.");
 DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU.");
 DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");

View File

@@ -53,7 +53,7 @@ python3 -m pip install paddleocr
 ![](./images/8dca91f016884e16ad9216d416da72ea08190f97d87b4be883f15079b7ebab9a.jpeg)
 ```bash linenums="1"
-paddleocr --lang=ch --det=Fase --image_dir=data
+paddleocr --lang=ch --det=False --image_dir=data
 ```
 The test results are as follows:
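The same recognition-only call can also be made from Python instead of the CLI. A minimal sketch using the PaddleOCR 2.x API; the image path is a placeholder and should point at a cropped text-line image:

```python
from paddleocr import PaddleOCR

# Recognition only, mirroring `--det=False` above; models download on first run.
ocr = PaddleOCR(lang="ch")
result = ocr.ocr("data/word_1.png", det=False)
print(result)
```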

View File

@@ -182,7 +182,7 @@ python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o G
 ```python linenums="1"
 cd ./pretrained_models/
-# transform teacher params in best_accuracy.pdparams into teacher_dml.paramers
+# transform teacher params in best_accuracy.pdparams into teacher_dml.pdparams
 import paddle
 # load pretrained model
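For context, the transform the corrected comment refers to is a key-remapping of the saved weights before they are reused in a DML (mutual-learning) config. The sketch below only illustrates that pattern and is not the documented snippet; the "Teacher."/"Student." prefixes and file names are assumptions that must match the distillation yml:

```python
import paddle

# Load the trained weights and re-save them under both sub-model prefixes
# expected by a DML architecture (prefixes are assumed, adjust to your config).
params = paddle.load("best_accuracy.pdparams")
base = {k.split(".", 1)[-1]: v for k, v in params.items()}
dml_params = {}
for prefix in ("Teacher.", "Student."):
    dml_params.update({prefix + k: v for k, v in base.items()})
paddle.save(dml_params, "teacher_dml.pdparams")
```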

View File

@@ -32,7 +32,7 @@ PaddleOCR is being oversight by a [PMC](https://github.com/PaddlePaddle/PaddleOC
 - **🔥 2024.10.18 release PaddleOCR v2.9, including**:
   - PaddleX, an All-in-One development tool based on PaddleOCR's advanced technology, supports low-code full-process development capabilities in the OCR field:
-    - 🎨 [**Rich Model One-Click Call**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/quick_start.html): Integrates **17 models** related to text image intelligent analysis, general OCR, general layout parsing, table recognition, formula recognition, and seal recognition into 6 pipelines, which can be quickly experienced through a simple **Python API one-click call**. In addition, the same set of APIs also supports a total of **200+ models** in image classification, object detection, image segmentation, and time series forcasting, forming 20+ single-function modules, making it convenient for developers to use **model combinations**.
+    - 🎨 [**Rich Model One-Click Call**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/quick_start.html): Integrates **17 models** related to text image intelligent analysis, general OCR, general layout parsing, table recognition, formula recognition, and seal recognition into 6 pipelines, which can be quickly experienced through a simple **Python API one-click call**. In addition, the same set of APIs also supports a total of **200+ models** in image classification, object detection, image segmentation, and time series forecasting, forming 20+ single-function modules, making it convenient for developers to use **model combinations**.
     - 🚀 [**High Efficiency and Low barrier of entry**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/overview.html): Provides two methods based on **unified commands** and **GUI** to achieve simple and efficient use, combination, and customization of models. Supports multiple deployment methods such as **high-performance inference, service-oriented deployment, and edge deployment**. Additionally, for various mainstream hardware such as **NVIDIA GPU, Kunlunxin XPU, Ascend NPU, Cambricon MLU, and Haiguang DCU**, models can be developed with **seamless switching**.

View File

@@ -2,7 +2,7 @@
 The All-in-One development tool [PaddleX](https://github.com/PaddlePaddle/PaddleX/tree/release/3.0-beta1), based on the advanced technology of PaddleOCR, supports **low-code full-process** development capabilities in the OCR field. Through low-code development, simple and efficient model use, combination, and customization can be achieved. This will significantly **reduce the time consumption** of model development, **lower its development difficulty**, and greatly accelerate the application and promotion speed of models in the industry. Features include:
-* 🎨 [**Rich Model One-Click Call**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/quick_start.html): Integrates **17 models** related to text image intelligent analysis, general OCR, general layout parsing, table recognition, formula recognition, and seal recognition into 6 pipelines, which can be quickly experienced through a simple **Python API one-click call**. In addition, the same set of APIs also supports a total of **200+ models** in image classification, object detection, image segmentation, and time series forcasting, forming 20+ single-function modules, making it convenient for developers to use **model combinations**.
+* 🎨 [**Rich Model One-Click Call**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/quick_start.html): Integrates **17 models** related to text image intelligent analysis, general OCR, general layout parsing, table recognition, formula recognition, and seal recognition into 6 pipelines, which can be quickly experienced through a simple **Python API one-click call**. In addition, the same set of APIs also supports a total of **200+ models** in image classification, object detection, image segmentation, and time series forecasting, forming 20+ single-function modules, making it convenient for developers to use **model combinations**.
 * 🚀 [**High Efficiency and Low barrier of entry**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/overview.html): Provides two methods based on **unified commands** and **GUI** to achieve simple and efficient use, combination, and customization of models. Supports multiple deployment methods such as **high-performance inference, service-oriented deployment, and edge deployment**. Additionally, for various mainstream hardware such as **NVIDIA GPU, Kunlunxin XPU, Ascend NPU, Cambricon MLU, and Haiguang DCU**, models can be developed with **seamless switching**.

View File

@@ -60,4 +60,4 @@ Here we have sorted out some Chinese OCR training and prediction tricks, which a
 - **Tips**
-There are two possible methods for space recognition. (1) Optimize the text detection. For spliting the text at the space in detection results, it needs to divide the text line with space into many segments when label the data for detection. (2) Optimize the text recognition. The space character is introduced into the recognition dictionary. Label the blank line in the training data for text recognition. In addition, we can also concat multiple word lines to synthesize the training data with spaces. PaddleOCR currently uses the second method.
+There are two possible methods for space recognition. (1) Optimize the text detection. For splitting the text at the space in detection results, it needs to divide the text line with space into many segments when label the data for detection. (2) Optimize the text recognition. The space character is introduced into the recognition dictionary. Label the blank line in the training data for text recognition. In addition, we can also concat multiple word lines to synthesize the training data with spaces. PaddleOCR currently uses the second method.
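A minimal sketch of method (2), appending a space character to the recognition charset; this mirrors what PaddleOCR's `use_space_char` option does for its label decoder. The helper name is illustrative and the dictionary path is the repo's default Chinese charset:

```python
def load_charset(dict_path, use_space_char=True):
    # One character per line in the dictionary file.
    with open(dict_path, "r", encoding="utf-8") as f:
        chars = [line.rstrip("\n") for line in f]
    if use_space_char and " " not in chars:
        chars.append(" ")  # the recognizer can now emit a space token
    return chars

charset = load_charset("ppocr/utils/ppocr_keys_v1.txt")
print(len(charset))
```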

View File

@@ -10,7 +10,7 @@ Paddle provides a variety of deployment schemes to meet the deployment requireme
 ![img](./images/deployment-20240704135743247.png)
-PP-OCR has supported muti deployment schemes. Click the link to get the specific tutorial.
+PP-OCR has supported multi deployment schemes. Click the link to get the specific tutorial.
 - [Python Inference](./python_infer.en.md)
 - [C++ Inference](./cpp_infer.en.md)

View File

@@ -9,7 +9,7 @@ hide:
 - **🔥 2024.10.18 release PaddleOCR v2.9, including**:
     * PaddleX, an All-in-One development tool based on PaddleOCR's advanced technology, supports low-code full-process development capabilities in the OCR field:
-        * 🎨 [**Rich Model One-Click Call**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/quick_start.html): Integrates **17 models** related to text image intelligent analysis, general OCR, general layout parsing, table recognition, formula recognition, and seal recognition into 6 pipelines, which can be quickly experienced through a simple **Python API one-click call**. In addition, the same set of APIs also supports a total of **200+ models** in image classification, object detection, image segmentation, and time series forcasting, forming 20+ single-function modules, making it convenient for developers to use **model combinations**.
+        * 🎨 [**Rich Model One-Click Call**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/quick_start.html): Integrates **17 models** related to text image intelligent analysis, general OCR, general layout parsing, table recognition, formula recognition, and seal recognition into 6 pipelines, which can be quickly experienced through a simple **Python API one-click call**. In addition, the same set of APIs also supports a total of **200+ models** in image classification, object detection, image segmentation, and time series forecasting, forming 20+ single-function modules, making it convenient for developers to use **model combinations**.
         * 🚀 [**High Efficiency and Low barrier of entry**](https://paddlepaddle.github.io/PaddleOCR/latest/en/paddlex/overview.html): Provides two methods based on **unified commands** and **GUI** to achieve simple and efficient use, combination, and customization of models. Supports multiple deployment methods such as **high-performance inference, service-oriented deployment, and edge deployment**. Additionally, for various mainstream hardware such as **NVIDIA GPU, Kunlunxin XPU, Ascend NPU, Cambricon MLU, and Haiguang DCU**, models can be developed with **seamless switching**.

View File

@@ -169,7 +169,7 @@ class EASTProcessTrain(object):
         used for generate the score map
         :param poly: the text poly
         :param r: r in the paper
-        :return: the shrinked poly
+        :return: the shrunk poly
         """
         # shrink ratio
         R = 0.3

View File

@@ -88,17 +88,17 @@ class MakePseGt(object):
             subject = [tuple(l) for l in poly]
             pco = pyclipper.PyclipperOffset()
             pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-            shrinked = np.array(pco.Execute(-distance))
+            shrunk = np.array(pco.Execute(-distance))
-            if len(shrinked) == 0 or shrinked.size == 0:
+            if len(shrunk) == 0 or shrunk.size == 0:
                 if ignore_tags is not None:
                     ignore_tags[i] = True
                 continue
             try:
-                shrinked = np.array(shrinked[0]).reshape(-1, 2)
+                shrunk = np.array(shrunk[0]).reshape(-1, 2)
             except:
                 if ignore_tags is not None:
                     ignore_tags[i] = True
                 continue
-            cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
+            cv2.fillPoly(text_kernel, [shrunk.astype(np.int32)], i + 1)
         return text_kernel, ignore_tags

View File

@@ -64,7 +64,7 @@ class MakeShrinkMap(object):
                 subject = [tuple(l) for l in polygon]
                 padding = pyclipper.PyclipperOffset()
                 padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-                shrinked = []
+                shrunk = []
                 # Increase the shrink ratio every time we get multiple polygon returned back
                 possible_ratios = np.arange(self.shrink_ratio, 1, self.shrink_ratio)
@@ -77,18 +77,18 @@ class MakeShrinkMap(object):
                         * (1 - np.power(ratio, 2))
                         / polygon_shape.length
                     )
-                    shrinked = padding.Execute(-distance)
-                    if len(shrinked) == 1:
+                    shrunk = padding.Execute(-distance)
+                    if len(shrunk) == 1:
                         break
-                if shrinked == []:
+                if shrunk == []:
                     cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
                     ignore_tags[i] = True
                     continue
-                for each_shirnk in shrinked:
-                    shirnk = np.array(each_shirnk).reshape(-1, 2)
-                    cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
+                for each_shrink in shrunk:
+                    shrink = np.array(each_shrink).reshape(-1, 2)
+                    cv2.fillPoly(gt, [shrink.astype(np.int32)], 1)
         data["shrink_map"] = gt
         data["shrink_mask"] = mask
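A standalone sketch of the shrink step used in the hunks above, built with the same libraries (shapely + pyclipper) that the transform relies on; the rectangle coordinates and shrink ratio are made-up example values:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

polygon = np.array([[10, 10], [110, 10], [110, 60], [10, 60]], dtype=np.float32)
shrink_ratio = 0.4

poly = Polygon(polygon)
# Offset distance D = A * (1 - r^2) / L, as computed in the loop above.
distance = poly.area * (1 - np.power(shrink_ratio, 2)) / poly.length

padding = pyclipper.PyclipperOffset()
padding.AddPath([tuple(p) for p in polygon], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunk = padding.Execute(-distance)  # negative offset moves the edges inward
print(np.array(shrunk[0]).reshape(-1, 2))
```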

View File

@@ -48,13 +48,13 @@ class AsterHead(nn.Layer):
             in_channels, out_channels, sDim, attDim, max_len_labels
         )
         self.time_step = time_step
-        self.embeder = Embedding(self.time_step, in_channels)
+        self.embedder = Embedding(self.time_step, in_channels)
         self.beam_width = beam_width
         self.eos = self.num_classes - 3
     def forward(self, x, targets=None, embed=None):
         return_dict = {}
-        embedding_vectors = self.embeder(x)
+        embedding_vectors = self.embedder(x)
         if self.training:
             rec_targets, rec_lengths, _ = targets

View File

@@ -166,7 +166,7 @@ class Encoder(nn.Layer):
                     ),
                 )
             )
-        self.processer = PrePostProcessLayer(
+        self.processor = PrePostProcessLayer(
             preprocess_cmd, d_model, prepostprocess_dropout
         )
@@ -174,7 +174,7 @@
         for encoder_layer in self.encoder_layers:
             enc_output = encoder_layer(enc_input, attn_bias)
             enc_input = enc_output
-        enc_output = self.processer(enc_output)
+        enc_output = self.processor(enc_output)
         return enc_output

View File

@@ -508,7 +508,7 @@ class APP_Image2Doc(QWidget):
         self.pb.setRange(0, max)
     def handleEndsignalSignal(self):
-        # enble buttons
+        # enable buttons
        self.openFileButton.setEnabled(True)
        self.startCNButton.setEnabled(True)
        self.startENButton.setEnabled(True)

View File

@@ -66,7 +66,7 @@ function func_cpp_inference(){
 for batch_size in ${cpp_batch_size_list[*]}; do
     precision="fp32"
     if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
-        precison="int8"
+        precision="int8"
     fi
     _save_log_path="${_log_path}/cpp_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
     set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")