mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-26 07:14:45 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			88 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			88 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import paddle
 | |
| import numpy as np
 | |
| import copy
 | |
| 
 | |
| 
 | |
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
    """Group text-instance samples by image group and resize each group to ``tcl_bs``.

    Entry ``[0, 0]`` of every element of ``pos_lists`` is read as the image id
    of that sample. The id is split into a group index (``img_id // img_bs``)
    and a group-local id (``img_id % img_bs``). The local id is written into
    column 0 of a copy of the sample. Because ``img_bs`` equals
    ``batch_size``, there is currently exactly one group.

    Each group is then brought to exactly ``tcl_bs`` samples:

    * A smaller group is padded by appending copies of its own samples,
      cycling through them in order.
    * A larger group drops samples at random indices drawn from
      ``np.random``, so the result depends on the global NumPy seed.

    Inputs are never mutated; arrays are ``.copy()``-ed and labels are
    deep-copied.

    Args:
        batch_size: number of images in the batch (must be non-zero).
        pos_lists: sequence of 2-D numpy arrays whose column 0 holds the
            image id.
        pos_masks: sequence of numpy arrays aligned index-by-index with
            ``pos_lists``.
        label_lists: sequence of labels aligned index-by-index with
            ``pos_lists``.
        tcl_bs: number of samples each group must contain on return.

    Returns:
        A tuple ``(pos_lists_, pos_masks_, label_lists_)`` of lists. Each
        list concatenates the resized groups, giving ``ngpu * tcl_bs``
        entries.

    Raises:
        ValueError: if a group has no samples but must be padded up to
            ``tcl_bs`` (for example, when ``pos_lists`` is empty and
            ``tcl_bs > 0``).
    """
    img_bs = batch_size
    # img_bs == batch_size, so this is always 1. The split/merge structure is
    # kept so a per-device grouping can be reintroduced without reshaping
    # the code.
    ngpu = batch_size // img_bs

    pos_split = [[] for _ in range(ngpu)]
    mask_split = [[] for _ in range(ngpu)]
    label_split = [[] for _ in range(ngpu)]

    for i, pos in enumerate(pos_lists):
        # Read the image id straight from entry [0, 0] instead of stacking
        # every sample into one int32 array just to slice a single column.
        img_id = int(pos[0][0])
        gpu_id, local_id = divmod(img_id, img_bs)
        pos_copy = pos.copy()
        pos_copy[:, 0] = local_id
        pos_split[gpu_id].append(pos_copy)
        mask_split[gpu_id].append(pos_masks[i].copy())
        label_split[gpu_id].append(copy.deepcopy(label_lists[i]))

    # Repeat or delete so that every group holds exactly tcl_bs samples.
    for gpu_id in range(ngpu):
        poss = pos_split[gpu_id]
        masks = mask_split[gpu_id]
        labels = label_split[gpu_id]
        count = len(poss)
        if count < tcl_bs:
            if count == 0:
                raise ValueError(
                    "org_tcl_rois: image group %d has no text instance; "
                    "cannot pad it to tcl_bs=%d samples" % (gpu_id, tcl_bs))
            # The lists grow while we append, so index j always refers to an
            # existing sample; this cycles through the originals in order.
            for j in range(tcl_bs - count):
                poss.append(poss[j].copy())
                masks.append(masks[j].copy())
                labels.append(copy.deepcopy(labels[j]))
        else:
            for _ in range(count - tcl_bs):
                # permutation(...)[0] is kept (rather than randint) so the
                # random draws match previous seeded runs.
                pop_id = np.random.permutation(len(poss))[0]
                poss.pop(pop_id)
                masks.pop(pop_id)
                labels.pop(pop_id)

    # Merge the groups back into flat lists, in group order.
    pos_out, mask_out, label_out = [], [], []
    for gpu_id in range(ngpu):
        pos_out.extend(pos_split[gpu_id])
        mask_out.extend(mask_split[gpu_id])
        label_out.extend(label_split[gpu_id])
    return pos_out, mask_out, label_out
 | |
| 
 | |
| 
 | |
def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
                pad_num, tcl_bs):
    """Collect the valid text instances of a batch and build training targets.

    An (image, slot) pair is kept when ``pos_mask[image, slot]`` has any
    non-zero entry. The kept positions, masks and labels are passed to
    ``org_tcl_rois`` together with ``tcl_bs``. For each of the first
    ``tcl_bs`` returned labels, the label length is counted as the number of
    leading rows whose first element differs from ``pad_num``, scanning at
    most ``max_text_length`` rows.

    Args:
        label_list: tensor exposing ``.numpy()``; its shape must unpack into
            four values. Presumably (batch, max_text_nums, max_text_length, 1)
            -- TODO confirm against the data pipeline.
        pos_list: tensor exposing ``.numpy()``, indexed as ``[image][slot]``.
        pos_mask: tensor exposing ``.numpy()``, indexed as ``[image, slot]``.
        max_text_length: maximum number of label rows scanned per sample.
        max_text_nums: number of text slots visited per image.
        pad_num: label value that terminates a label sequence.
        tcl_bs: number of samples for which label lengths are computed.

    Returns:
        A tuple ``(pos_list, pos_mask, label_list, label)`` of paddle tensors:

        * ``pos_list``: the regrouped positions.
        * ``pos_mask``: the regrouped masks.
        * ``label_list``: the labels squeezed on axis 2 and cast to int32.
        * ``label``: the per-sample label lengths as int64.
    """
    labels_np = label_list.numpy()
    batch, _, _, _ = labels_np.shape  # unpacking requires a rank-4 label tensor
    pos_np = pos_list.numpy()
    mask_np = pos_mask.numpy()

    # (image, slot) pairs whose mask marks a real text instance.
    valid = [(img, slot)
             for img in range(batch)
             for slot in range(max_text_nums)
             if mask_np[img, slot].any()]
    kept_pos = [pos_np[img][slot] for img, slot in valid]
    kept_mask = [mask_np[img][slot] for img, slot in valid]
    kept_label = [labels_np[img][slot] for img, slot in valid]

    kept_pos, kept_mask, kept_label = org_tcl_rois(
        batch, kept_pos, kept_mask, kept_label, tcl_bs)

    rows_per_sample = [arr.tolist() for arr in kept_label]
    lengths = []
    for idx in range(tcl_bs):
        rows = rows_per_sample[idx]
        n_chars = 0
        # Count leading non-padding rows; stop at the first pad or at the cap.
        while n_chars < max_text_length and rows[n_chars][0] != pad_num:
            n_chars += 1
        lengths.append(n_chars)

    label = paddle.cast(paddle.to_tensor(lengths), dtype='int64')
    pos_list = paddle.to_tensor(kept_pos)
    pos_mask = paddle.to_tensor(kept_mask)
    label_list = paddle.cast(
        paddle.squeeze(paddle.to_tensor(kept_label), axis=2), dtype='int32')
    return pos_list, pos_mask, label_list, label
 | 
