Mirror of https://github.com/PaddlePaddle/PaddleOCR.git (synced 2025-11-04 03:39:22 +00:00)

Commit 16c247ac46 (parent 7c8b2c8d19): refine
@@ -1,13 +1,12 @@
 Global:
   use_gpu: true
-  epoch_num: 40
+  epoch_num: 50
   log_smooth_window: 20
   print_batch_step: 5
   save_model_dir: ./output/table_mv3/
-  save_epoch_step: 3
+  save_epoch_step: 5
-  # evaluation is run every 5000 iterations after the 4000th iteration
+  # evaluation is run every 400 iterations after the 0th iteration
   eval_batch_step: [0, 400]
-  # if pretrained_model is saved in static mode, load_static_weights must set to True
   cal_metric_during_train: True
   pretrained_model:
   checkpoints:
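Note: the corrected comment now matches the config value. A minimal sketch of how a `[start, interval]` style `eval_batch_step` is conventionally consumed by a training loop (illustrative only, not the repo's actual loop):

```python
# Illustrative only: eval_batch_step = [0, 400] means "evaluate every
# 400 iterations, starting after iteration 0".
start_eval_step, eval_interval = [0, 400]
for global_step in range(1, 1201):
    if global_step > start_eval_step and global_step % eval_interval == 0:
        print(f"run eval at step {global_step}")  # fires at 400, 800, 1200
```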
@@ -18,19 +17,20 @@ Global:
   character_dict_path: ppocr/utils/dict/table_structure_dict.txt
   character_type: en
   max_text_length: 100
-  max_elem_length: 800
+  max_elem_length: 500
   max_cell_num: 500
   infer_mode: False
   process_total_num: 0
   process_cut_num: 0
 
+
 Optimizer:
   name: Adam
   beta1: 0.9
   beta2: 0.999
   clip_norm: 5.0
   lr:
-    learning_rate: 0.0001
+    learning_rate: 0.001
   regularizer:
     name: 'L2'
     factor: 0.00000
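Note: the learning rate rises from 0.0001 to 0.001 while the rest of the optimizer block is unchanged. A hedged sketch of the optimizer this block describes, written against Paddle's public API rather than the repo's config-driven builder:

```python
import paddle

model = paddle.nn.Linear(10, 10)  # stand-in for the table model
# clip_norm: 5.0 in the config corresponds to gradient-norm clipping.
clip = paddle.nn.ClipGradByNorm(clip_norm=5.0)
optimizer = paddle.optimizer.Adam(
    learning_rate=0.001,                           # raised from 0.0001
    beta1=0.9,
    beta2=0.999,
    grad_clip=clip,
    weight_decay=paddle.regularizer.L2Decay(0.0),  # factor: 0.00000
    parameters=model.parameters())
```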
@@ -41,12 +41,12 @@ Architecture:
   Backbone:
     name: MobileNetV3
     scale: 1.0
-    model_name: large
+    model_name: small
+    disable_se: True
   Head:
-    name: TableAttentionHead  # AttentionHead
+    name: TableAttentionHead
-    hidden_size: 256 #
+    hidden_size: 256
     l2_decay: 0.00001
-#     loc_type: 1
     loc_type: 2
 
 Loss:
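Note: the backbone shrinks from MobileNetV3-large to MobileNetV3-small and squeeze-and-excitation is switched off. A hypothetical sketch of what a `disable_se` flag typically does inside a backbone block (the real gating lives in the repo's MobileNetV3 implementation):

```python
import paddle.nn as nn

class Block(nn.Layer):
    def __init__(self, channels, use_se, disable_se=False):
        super().__init__()
        # A global disable_se flag overrides the per-block use_se setting.
        self.use_se = use_se and not disable_se
        if self.use_se:
            # Stand-in for the real SE module (pool, 1x1 convs, rescale).
            self.se = nn.Sequential(
                nn.AdaptiveAvgPool2D(1),
                nn.Conv2D(channels, channels, 1),
                nn.Sigmoid())

    def forward(self, x):
        return x * self.se(x) if self.use_se else x
```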
@@ -86,7 +86,7 @@ Train:
     shuffle: True
     batch_size_per_card: 32
     drop_last: True
-    num_workers: 4
+    num_workers: 1
 
 Eval:
   dataset:
@@ -113,4 +113,4 @@ Eval:
     shuffle: False
     drop_last: False
     batch_size_per_card: 16
-    num_workers: 4
+    num_workers: 1
@@ -412,7 +412,6 @@ class TableLabelEncode(object):
             return None
         elem_num = len(structure)
         structure = [0] + structure + [len(self.dict_elem) - 1]
-#         structure = [0] + structure + [0]
         structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
         structure = np.array(structure)
         data['structure'] = structure
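Note: only the dead commented-out variant is dropped here; the live padding logic is untouched. A worked example of that padding, assuming a 30-entry element dict and the new `max_elem_length` of 500:

```python
structure = [5, 7, 7, 6]                 # raw element indices (example)
structure = [0] + structure + [30 - 1]   # prepend beg token, append end token
structure += [0] * (500 + 2 - len(structure))  # pad to max_elem_length + 2
assert len(structure) == 502
```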
@@ -443,8 +442,6 @@ class TableLabelEncode(object):
                 if cand_span_idx < (self.max_elem_length + 2):
                     if structure[cand_span_idx] in span_idx_list:
                         structure_mask[cand_span_idx] = span_weight
-#                         structure_mask[td_idx] = self.span_weight
-#                         structure_mask[cand_span_idx] = self.span_weight
 
         data['bbox_list'] = bbox_list
         data['bbox_list_mask'] = bbox_list_mask
@@ -458,23 +455,6 @@ class TableLabelEncode(object):
             self.max_elem_length, self.max_cell_num, elem_num])
         return data
 
-        ########
-        # for char decode
-#         cell_list = []
-#         for cell in cells:
-#             char_list = cell['tokens']
-#             cell = self.encode(char_list, 'char')
-#             if cell is None:
-#                 return None
-#             cell = [0] + cell + [len(self.dict_character) - 1]
-#             cell = cell + [0] * (self.max_text_length + 2 - len(cell))
-#             cell_list.append(cell)
-#         cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
-#         cell_list = np.array(cell_list)
-#         cell_list_padding[0:cell_list.shape[0]] = cell_list
-#         data['cells'] = cell_list_padding
-#         return data
-
     def encode(self, text, char_or_elem):
         """convert text-label into text-index.
         """
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ import json
 
 from .imaug import transform, create_operators
 
+
 class PubTabDataSet(Dataset):
     def __init__(self, config, mode, logger, seed=None):
         super(PubTabDataSet, self).__init__()
@@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
             random.shuffle(self.data_lines)
         return
 
-    def load_hard_select_prob(self):
-        label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
-        img_select_prob = {}
-        with open(label_path, "rb") as fin:
-            lines = fin.readlines()
-            for lno in range(len(lines)):
-                substr = lines[lno].decode('utf-8').strip("\n").split(" ")
-                img_name = substr[0].strip(":")
-                score = float(substr[1])
-                if score <= 0.8:
-                    img_select_prob[img_name] = self.hard_prob[0]
-                elif score <= 0.98:
-                    img_select_prob[img_name] = self.hard_prob[1]
-                else:
-                    img_select_prob[img_name] = self.hard_prob[2]
-        return img_select_prob
-
     def __getitem__(self, idx):
         try:
             data_line = self.data_lines[idx]
@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
                 table_type = "simple"
                 if 'colspan' in structure_str or 'rowspan' in structure_str:
                     table_type = "complex"
-#                 if self.table_select_type != table_type:
-#                     select_flag = False
                 if table_type == "complex":
                     if self.table_select_prob < random.uniform(0, 1):
                         select_flag = False
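Note: the commented-out `table_select_type` filter is removed; what survives is probabilistic subsampling of complex tables. The rule in isolation, as a hedged sketch with a hypothetical standalone helper:

```python
import random

def keep_sample(structure_str, table_select_prob=1.0):
    # Tables containing spans count as "complex"; they are kept with
    # probability table_select_prob, while simple tables always pass.
    is_complex = 'colspan' in structure_str or 'rowspan' in structure_str
    if is_complex and table_select_prob < random.uniform(0, 1):
        return False
    return True
```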
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,13 +21,16 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 import numpy as np
 
 
 class TableAttentionHead(nn.Layer):
     def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
         self.hidden_size = hidden_size
-        self.char_num = 280
         self.elem_num = 30
+        self.max_text_length = 100
+        self.max_elem_length = 500
+        self.max_cell_num = 500
 
         self.structure_attention_cell = AttentionGRUCell(
             self.input_size, hidden_size, self.elem_num, use_gru=False)
@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
             self.loc_generator = nn.Linear(hidden_size, 4)
         else:
             if self.in_max_len == 640:
-                self.loc_fea_trans = nn.Linear(400, 801)
+                self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
             elif self.in_max_len == 800:
-                self.loc_fea_trans = nn.Linear(625, 801)
+                self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
             else:
-                self.loc_fea_trans = nn.Linear(256, 801)
+                self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
             self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
 
     def _char_to_onehot(self, input_char, onehot_dim):
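Note: the magic number 801 was `max_elem_length + 1` under the old limit of 800. Deriving the projection width from the new attribute keeps the head consistent with the config (now 500 + 1 = 501). The arithmetic, as a standalone sketch:

```python
import paddle.nn as nn

max_elem_length = 500  # now stored on the layer as self.max_elem_length
loc_fea_trans = nn.Linear(256, max_elem_length + 1)  # was nn.Linear(256, 801)
```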
@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
             fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
             fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
         batch_size = fea.shape[0]
-        #sp_tokens = targets[2].numpy()
-        #char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
-        #elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
-        #elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
-        #max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
-        max_text_length, max_elem_length, max_cell_num = 100, 800, 500
 
         hidden = paddle.zeros((batch_size, self.hidden_size))
         output_hiddens = []
         if mode == 'Train' and targets is not None:
             structure = targets[0]
-            for i in range(max_elem_length+1):
+            for i in range(self.max_elem_length+1):
                 elem_onehots = self._char_to_onehot(
                     structure[:, i], onehot_dim=self.elem_num)
                 (outputs, hidden), alpha = self.structure_attention_cell(
@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
             elem_onehots = None
             outputs = None
             alpha = None
-            max_elem_length = paddle.to_tensor(max_elem_length)
+            max_elem_length = paddle.to_tensor(self.max_elem_length)
             i = 0
             while i < max_elem_length+1:
                 elem_onehots = self._char_to_onehot(
@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
                 loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
 
+
 class AttentionGRUCell(nn.Layer):
     def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
         super(AttentionGRUCell, self).__init__()
@@ -1,4 +1,4 @@
-# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
             in_channels=in_channels[0],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_51.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in3_conv = nn.Conv2D(
             in_channels=in_channels[1],
             out_channels=self.out_channels,
             kernel_size=1,
             stride = 1,
-            weight_attr=ParamAttr(
-                name='conv2d_50.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in4_conv = nn.Conv2D(
             in_channels=in_channels[2],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_49.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in5_conv = nn.Conv2D(
            in_channels=in_channels[3],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_48.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p5_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_52.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p4_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_53.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p3_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_54.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p2_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_55.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.fuse_conv = nn.Conv2D(
             in_channels=self.out_channels * 4,
             out_channels=512,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
+            weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
 
     def forward(self, x):
         c2, c3, c4, c5 = x
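Note: every conv in TableFPN drops its pinned parameter name (e.g. 'conv2d_51.w_0', a leftover from static-graph checkpoints) and keeps only the initializer, letting dygraph auto-name parameters. The simplified pattern, with an assumed initializer for illustration:

```python
import paddle.nn as nn
from paddle import ParamAttr

weight_attr = nn.initializer.KaimingUniform()  # assumed; any initializer works
conv = nn.Conv2D(
    in_channels=96, out_channels=96, kernel_size=1,
    weight_attr=ParamAttr(initializer=weight_attr),  # no explicit name
    bias_attr=False)
```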
@@ -369,18 +369,6 @@ class TableLabelDecode(object):
         list_character = [self.beg_str] + list_character + [self.end_str]
         return list_character
 
-    def get_sp_tokens(self):
-        char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
-        char_end_idx = self.get_beg_end_flag_idx('end', 'char')
-        elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
-        elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
-        elem_char_idx1 = self.dict_elem['<td>']
-        elem_char_idx2 = self.dict_elem['<td']
-        sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
-            elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
-            self.max_elem_length, self.max_cell_num])
-        return sp_tokens
-
     def __call__(self, preds):
         structure_probs = preds['structure_probs']
         loc_preds = preds['loc_preds']
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
                     "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                 )
                 infer_shape[-1] = 100
+        elif arch_config["model_type"] == "table":
+            infer_shape = [3, 488, 488]
         model = to_static(
             model,
             input_spec=[
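Note: the new branch gives table models a fixed 488x488 export shape. A sketch of the input spec this feeds into `paddle.jit.to_static`, with the batch dimension left dynamic:

```python
from paddle.static import InputSpec

infer_shape = [3, 488, 488]  # the new table branch
spec = InputSpec(shape=[None] + infer_shape, dtype="float32")
# model = paddle.jit.to_static(model, input_spec=[spec])  # as in export_single_model
```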
@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
             img = f.read()
             data = {'image': img}
         batch = transform(data, ops)
-        sp_tokens = post_process_class.get_sp_tokens()
-        targets = [[], [], paddle.to_tensor([sp_tokens])]
         images = np.expand_dims(batch[0], axis=0)
         images = paddle.to_tensor(images)
-        preds = model(images, data=targets, mode='Test')
+        preds = model(images, data=None, mode='Test')
         post_result = post_process_class(preds)
         res_html_code = post_result['res_html_code']
         res_loc = post_result['res_loc']
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -276,6 +276,7 @@ def train(config,
                     valid_dataloader,
                     post_process_class,
                     eval_class,
+                    "table",
                     use_srn=use_srn)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
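Note: the evaluator gains a positional model-type argument so table metrics can take a different path from rec/det. Reconstructed call site, shown as an excerpt rather than a runnable snippet (the leading `model` argument is assumed from context; the hunk shows only the trailing arguments):

```python
cur_metric = eval(model, valid_dataloader, post_process_class,
                  eval_class, "table", use_srn=use_srn)
```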