mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-11-03 19:29:18 +00:00 
			
		
		
		
	refine
This commit is contained in:
		
							parent
							
								
									7c8b2c8d19
								
							
						
					
					
						commit
						16c247ac46
					
				@ -1,13 +1,12 @@
 | 
			
		||||
Global:
 | 
			
		||||
  use_gpu: true
 | 
			
		||||
  epoch_num: 40
 | 
			
		||||
  epoch_num: 50
 | 
			
		||||
  log_smooth_window: 20
 | 
			
		||||
  print_batch_step: 5
 | 
			
		||||
  save_model_dir: ./output/table_mv3/
 | 
			
		||||
  save_epoch_step: 3
 | 
			
		||||
  # evaluation is run every 5000 iterations after the 4000th iteration
 | 
			
		||||
  save_epoch_step: 5
 | 
			
		||||
  # evaluation is run every 400 iterations after the 0th iteration
 | 
			
		||||
  eval_batch_step: [0, 400]
 | 
			
		||||
  # if pretrained_model is saved in static mode, load_static_weights must set to True
 | 
			
		||||
  cal_metric_during_train: True
 | 
			
		||||
  pretrained_model: 
 | 
			
		||||
  checkpoints: 
 | 
			
		||||
@ -18,19 +17,20 @@ Global:
 | 
			
		||||
  character_dict_path: ppocr/utils/dict/table_structure_dict.txt
 | 
			
		||||
  character_type: en
 | 
			
		||||
  max_text_length: 100
 | 
			
		||||
  max_elem_length: 800
 | 
			
		||||
  max_elem_length: 500
 | 
			
		||||
  max_cell_num: 500
 | 
			
		||||
  infer_mode: False
 | 
			
		||||
  process_total_num: 0
 | 
			
		||||
  process_cut_num: 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Optimizer:
 | 
			
		||||
  name: Adam
 | 
			
		||||
  beta1: 0.9
 | 
			
		||||
  beta2: 0.999
 | 
			
		||||
  clip_norm: 5.0
 | 
			
		||||
  lr:
 | 
			
		||||
    learning_rate: 0.0001
 | 
			
		||||
    learning_rate: 0.001
 | 
			
		||||
  regularizer:
 | 
			
		||||
    name: 'L2'
 | 
			
		||||
    factor: 0.00000
 | 
			
		||||
@ -41,12 +41,12 @@ Architecture:
 | 
			
		||||
  Backbone:
 | 
			
		||||
    name: MobileNetV3
 | 
			
		||||
    scale: 1.0
 | 
			
		||||
    model_name: large
 | 
			
		||||
    model_name: small
 | 
			
		||||
    disable_se: True
 | 
			
		||||
  Head:
 | 
			
		||||
    name: TableAttentionHead  # AttentionHead
 | 
			
		||||
    hidden_size: 256 #
 | 
			
		||||
    name: TableAttentionHead
 | 
			
		||||
    hidden_size: 256
 | 
			
		||||
    l2_decay: 0.00001
 | 
			
		||||
#     loc_type: 1
 | 
			
		||||
    loc_type: 2
 | 
			
		||||
 | 
			
		||||
Loss:
 | 
			
		||||
@ -86,7 +86,7 @@ Train:
 | 
			
		||||
    shuffle: True
 | 
			
		||||
    batch_size_per_card: 32
 | 
			
		||||
    drop_last: True
 | 
			
		||||
    num_workers: 4
 | 
			
		||||
    num_workers: 1
 | 
			
		||||
 | 
			
		||||
Eval:
 | 
			
		||||
  dataset:
 | 
			
		||||
@ -113,4 +113,4 @@ Eval:
 | 
			
		||||
    shuffle: False
 | 
			
		||||
    drop_last: False
 | 
			
		||||
    batch_size_per_card: 16
 | 
			
		||||
    num_workers: 4
 | 
			
		||||
    num_workers: 1
 | 
			
		||||
 | 
			
		||||
@ -412,7 +412,6 @@ class TableLabelEncode(object):
 | 
			
		||||
            return None
 | 
			
		||||
        elem_num = len(structure)
 | 
			
		||||
        structure = [0] + structure + [len(self.dict_elem) - 1]
 | 
			
		||||
#         structure = [0] + structure + [0]
 | 
			
		||||
        structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
 | 
			
		||||
        structure = np.array(structure)
 | 
			
		||||
        data['structure'] = structure
 | 
			
		||||
@ -443,8 +442,6 @@ class TableLabelEncode(object):
 | 
			
		||||
                if cand_span_idx < (self.max_elem_length + 2):
 | 
			
		||||
                    if structure[cand_span_idx] in span_idx_list:
 | 
			
		||||
                        structure_mask[cand_span_idx] = span_weight
 | 
			
		||||
#                         structure_mask[td_idx] = self.span_weight
 | 
			
		||||
#                         structure_mask[cand_span_idx] = self.span_weight
 | 
			
		||||
 | 
			
		||||
        data['bbox_list'] = bbox_list
 | 
			
		||||
        data['bbox_list_mask'] = bbox_list_mask
 | 
			
		||||
@ -458,23 +455,6 @@ class TableLabelEncode(object):
 | 
			
		||||
            self.max_elem_length, self.max_cell_num, elem_num])
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
        ########
 | 
			
		||||
        # for char decode
 | 
			
		||||
#         cell_list = []
 | 
			
		||||
#         for cell in cells:
 | 
			
		||||
#             char_list = cell['tokens']
 | 
			
		||||
#             cell = self.encode(char_list, 'char')
 | 
			
		||||
#             if cell is None:
 | 
			
		||||
#                 return None
 | 
			
		||||
#             cell = [0] + cell + [len(self.dict_character) - 1]
 | 
			
		||||
#             cell = cell + [0] * (self.max_text_length + 2 - len(cell))
 | 
			
		||||
#             cell_list.append(cell)
 | 
			
		||||
#         cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
 | 
			
		||||
#         cell_list = np.array(cell_list)
 | 
			
		||||
#         cell_list_padding[0:cell_list.shape[0]] = cell_list
 | 
			
		||||
#         data['cells'] = cell_list_padding
 | 
			
		||||
#         return data
 | 
			
		||||
 | 
			
		||||
    def encode(self, text, char_or_elem):
 | 
			
		||||
        """convert text-label into text-index.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
 | 
			
		||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
@ -19,6 +19,7 @@ import json
 | 
			
		||||
 | 
			
		||||
from .imaug import transform, create_operators
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PubTabDataSet(Dataset):
 | 
			
		||||
    def __init__(self, config, mode, logger, seed=None):
 | 
			
		||||
        super(PubTabDataSet, self).__init__()
 | 
			
		||||
@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
 | 
			
		||||
            random.shuffle(self.data_lines)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    def load_hard_select_prob(self):
 | 
			
		||||
        label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
 | 
			
		||||
        img_select_prob = {}
 | 
			
		||||
        with open(label_path, "rb") as fin:
 | 
			
		||||
            lines = fin.readlines()
 | 
			
		||||
            for lno in range(len(lines)):
 | 
			
		||||
                substr = lines[lno].decode('utf-8').strip("\n").split(" ")
 | 
			
		||||
                img_name = substr[0].strip(":")
 | 
			
		||||
                score = float(substr[1])
 | 
			
		||||
                if score <= 0.8:
 | 
			
		||||
                    img_select_prob[img_name] = self.hard_prob[0]
 | 
			
		||||
                elif score <= 0.98:
 | 
			
		||||
                    img_select_prob[img_name] = self.hard_prob[1]
 | 
			
		||||
                else:
 | 
			
		||||
                    img_select_prob[img_name] = self.hard_prob[2]
 | 
			
		||||
        return img_select_prob
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, idx):
 | 
			
		||||
        try:
 | 
			
		||||
            data_line = self.data_lines[idx]
 | 
			
		||||
@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
 | 
			
		||||
                table_type = "simple"
 | 
			
		||||
                if 'colspan' in structure_str or 'rowspan' in structure_str:
 | 
			
		||||
                    table_type = "complex"
 | 
			
		||||
#                 if self.table_select_type != table_type:
 | 
			
		||||
#                     select_flag = False
 | 
			
		||||
                if table_type == "complex":
 | 
			
		||||
                    if self.table_select_prob < random.uniform(0, 1):
 | 
			
		||||
                        select_flag = False                    
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
			
		||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
 | 
			
		||||
@ -21,13 +21,16 @@ import paddle.nn as nn
 | 
			
		||||
import paddle.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TableAttentionHead(nn.Layer):
 | 
			
		||||
    def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
 | 
			
		||||
        super(TableAttentionHead, self).__init__()
 | 
			
		||||
        self.input_size = in_channels[-1]
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.char_num = 280
 | 
			
		||||
        self.elem_num = 30
 | 
			
		||||
        self.max_text_length = 100
 | 
			
		||||
        self.max_elem_length = 500
 | 
			
		||||
        self.max_cell_num = 500
 | 
			
		||||
 | 
			
		||||
        self.structure_attention_cell = AttentionGRUCell(
 | 
			
		||||
            self.input_size, hidden_size, self.elem_num, use_gru=False)
 | 
			
		||||
@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
 | 
			
		||||
            self.loc_generator = nn.Linear(hidden_size, 4)
 | 
			
		||||
        else:
 | 
			
		||||
            if self.in_max_len == 640:
 | 
			
		||||
                self.loc_fea_trans = nn.Linear(400, 801)
 | 
			
		||||
                self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
 | 
			
		||||
            elif self.in_max_len == 800:
 | 
			
		||||
                self.loc_fea_trans = nn.Linear(625, 801)
 | 
			
		||||
                self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
 | 
			
		||||
            else:
 | 
			
		||||
                self.loc_fea_trans = nn.Linear(256, 801)
 | 
			
		||||
                self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
 | 
			
		||||
            self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
 | 
			
		||||
            
 | 
			
		||||
    def _char_to_onehot(self, input_char, onehot_dim):
 | 
			
		||||
@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
 | 
			
		||||
            fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
 | 
			
		||||
            fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
 | 
			
		||||
        batch_size = fea.shape[0]
 | 
			
		||||
        #sp_tokens = targets[2].numpy()
 | 
			
		||||
        #char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
 | 
			
		||||
        #elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
 | 
			
		||||
        #elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
 | 
			
		||||
        #max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
 | 
			
		||||
        max_text_length, max_elem_length, max_cell_num = 100, 800, 500
 | 
			
		||||
        
 | 
			
		||||
        hidden = paddle.zeros((batch_size, self.hidden_size))
 | 
			
		||||
        output_hiddens = []
 | 
			
		||||
        if mode == 'Train' and targets is not None:
 | 
			
		||||
            structure = targets[0]
 | 
			
		||||
            for i in range(max_elem_length+1):
 | 
			
		||||
            for i in range(self.max_elem_length+1):
 | 
			
		||||
                elem_onehots = self._char_to_onehot(
 | 
			
		||||
                    structure[:, i], onehot_dim=self.elem_num)
 | 
			
		||||
                (outputs, hidden), alpha = self.structure_attention_cell(
 | 
			
		||||
@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
 | 
			
		||||
            elem_onehots = None
 | 
			
		||||
            outputs = None
 | 
			
		||||
            alpha = None
 | 
			
		||||
            max_elem_length = paddle.to_tensor(max_elem_length)
 | 
			
		||||
            max_elem_length = paddle.to_tensor(self.max_elem_length)
 | 
			
		||||
            i = 0
 | 
			
		||||
            while i < max_elem_length+1:
 | 
			
		||||
                elem_onehots = self._char_to_onehot(
 | 
			
		||||
@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
 | 
			
		||||
                loc_preds = F.sigmoid(loc_preds)
 | 
			
		||||
        return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
class AttentionGRUCell(nn.Layer):
 | 
			
		||||
    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
 | 
			
		||||
        super(AttentionGRUCell, self).__init__()
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
 | 
			
		||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
 | 
			
		||||
            in_channels=in_channels[0],
 | 
			
		||||
            out_channels=self.out_channels,
 | 
			
		||||
            kernel_size=1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_51.w_0', initializer=weight_attr),
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr),
 | 
			
		||||
            bias_attr=False)
 | 
			
		||||
        self.in3_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=in_channels[1],
 | 
			
		||||
            out_channels=self.out_channels,
 | 
			
		||||
            kernel_size=1,
 | 
			
		||||
            stride = 1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_50.w_0', initializer=weight_attr),
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr),
 | 
			
		||||
            bias_attr=False)
 | 
			
		||||
        self.in4_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=in_channels[2],
 | 
			
		||||
            out_channels=self.out_channels,
 | 
			
		||||
            kernel_size=1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_49.w_0', initializer=weight_attr),
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr),
 | 
			
		||||
            bias_attr=False)
 | 
			
		||||
        self.in5_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=in_channels[3],
 | 
			
		||||
            out_channels=self.out_channels,
 | 
			
		||||
            kernel_size=1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_48.w_0', initializer=weight_attr),
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr),
 | 
			
		||||
            bias_attr=False)
 | 
			
		||||
        self.p5_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=self.out_channels,
 | 
			
		||||
            out_channels=self.out_channels // 4,
 | 
			
		||||
            kernel_size=3,
 | 
			
		||||
            padding=1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_52.w_0', initializer=weight_attr),
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr),
 | 
			
		||||
            bias_attr=False)
 | 
			
		||||
        self.p4_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=self.out_channels,
 | 
			
		||||
            out_channels=self.out_channels // 4,
 | 
			
		||||
            kernel_size=3,
 | 
			
		||||
            padding=1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_53.w_0', initializer=weight_attr),
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr),
 | 
			
		||||
            bias_attr=False)
 | 
			
		||||
        self.p3_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=self.out_channels,
 | 
			
		||||
            out_channels=self.out_channels // 4,
 | 
			
		||||
            kernel_size=3,
 | 
			
		||||
            padding=1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_54.w_0', initializer=weight_attr),
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr),
 | 
			
		||||
            bias_attr=False)
 | 
			
		||||
        self.p2_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=self.out_channels,
 | 
			
		||||
            out_channels=self.out_channels // 4,
 | 
			
		||||
            kernel_size=3,
 | 
			
		||||
            padding=1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_55.w_0', initializer=weight_attr),
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr),
 | 
			
		||||
            bias_attr=False)
 | 
			
		||||
        self.fuse_conv = nn.Conv2D(
 | 
			
		||||
            in_channels=self.out_channels * 4,
 | 
			
		||||
            out_channels=512,
 | 
			
		||||
            kernel_size=3,
 | 
			
		||||
            padding=1,
 | 
			
		||||
            weight_attr=ParamAttr(
 | 
			
		||||
                name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
 | 
			
		||||
            weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        c2, c3, c4, c5 = x
 | 
			
		||||
 | 
			
		||||
@ -369,18 +369,6 @@ class TableLabelDecode(object):
 | 
			
		||||
        list_character = [self.beg_str] + list_character + [self.end_str]
 | 
			
		||||
        return list_character
 | 
			
		||||
    
 | 
			
		||||
    def get_sp_tokens(self):
 | 
			
		||||
        char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
 | 
			
		||||
        char_end_idx = self.get_beg_end_flag_idx('end', 'char')
 | 
			
		||||
        elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
 | 
			
		||||
        elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
 | 
			
		||||
        elem_char_idx1 = self.dict_elem['<td>']
 | 
			
		||||
        elem_char_idx2 = self.dict_elem['<td']
 | 
			
		||||
        sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx, 
 | 
			
		||||
            elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length, 
 | 
			
		||||
            self.max_elem_length, self.max_cell_num])
 | 
			
		||||
        return sp_tokens
 | 
			
		||||
    
 | 
			
		||||
    def __call__(self, preds):
 | 
			
		||||
        structure_probs = preds['structure_probs']
 | 
			
		||||
        loc_preds = preds['loc_preds']
 | 
			
		||||
 | 
			
		||||
@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
 | 
			
		||||
                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
 | 
			
		||||
                )
 | 
			
		||||
                infer_shape[-1] = 100
 | 
			
		||||
 | 
			
		||||
        elif arch_config["model_type"] == "table":
 | 
			
		||||
            infer_shape = [3, 488, 488]
 | 
			
		||||
        model = to_static(
 | 
			
		||||
            model,
 | 
			
		||||
            input_spec=[
 | 
			
		||||
 | 
			
		||||
@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
 | 
			
		||||
            img = f.read()
 | 
			
		||||
            data = {'image': img}
 | 
			
		||||
        batch = transform(data, ops)
 | 
			
		||||
        sp_tokens = post_process_class.get_sp_tokens()
 | 
			
		||||
        targets = [[], [], paddle.to_tensor([sp_tokens])]
 | 
			
		||||
        images = np.expand_dims(batch[0], axis=0)
 | 
			
		||||
        images = paddle.to_tensor(images)
 | 
			
		||||
        preds = model(images, data=targets, mode='Test')
 | 
			
		||||
        preds = model(images, data=None, mode='Test')
 | 
			
		||||
        post_result = post_process_class(preds)
 | 
			
		||||
        res_html_code = post_result['res_html_code']
 | 
			
		||||
        res_loc = post_result['res_loc']
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
			
		||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
@ -276,6 +276,7 @@ def train(config,
 | 
			
		||||
                    valid_dataloader,
 | 
			
		||||
                    post_process_class,
 | 
			
		||||
                    eval_class,
 | 
			
		||||
                    "table",
 | 
			
		||||
                    use_srn=use_srn)
 | 
			
		||||
                cur_metric_str = 'cur metric, {}'.format(', '.join(
 | 
			
		||||
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user