mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-31 09:49:30 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			791 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			791 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #    http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import numpy as np
 | |
| import paddle
 | |
| from paddle.nn import functional as F
 | |
| import re
 | |
| 
 | |
| 
 | |
class BaseRecLabelDecode(object):
    """Base decoder turning predicted character indices back into text.

    Subclasses adjust the vocabulary through ``add_special_char`` and the
    indices dropped while decoding through ``get_ignored_tokens``.
    """

    def __init__(self, character_dict_path=None, use_space_char=False):
        """Build the character list and the character -> index table.

        Args:
            character_dict_path: Path of a UTF-8 dictionary file with one
                entry per line. When ``None`` the built-in digits plus
                lowercase letters are used.
            use_space_char: Append ``" "`` to the vocabulary. Only applied
                when a dictionary file is given.
        """
        self.beg_str = "sos"
        self.end_str = "eos"

        if character_dict_path is None:
            # Default vocabulary is kept as a plain string.
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
        else:
            self.character_str = []
            with open(character_dict_path, "rb") as dict_file:
                for raw_line in dict_file.readlines():
                    entry = raw_line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str.append(entry)
            if use_space_char:
                self.character_str.append(" ")

        vocab = self.add_special_char(list(self.character_str))
        # For duplicated entries the last position wins.
        self.dict = {symbol: position for position, symbol in enumerate(vocab)}
        self.character = vocab

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; the base adds none."""
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert batches of character indices into ``(text, confidence)``.

        Args:
            text_index: Batch of index rows; each row is indexed with a
                boolean mask, so rows are expected to be numpy arrays.
            text_prob: Optional per-position scores, same layout as
                ``text_index``.
            is_remove_duplicate: Drop a position whose index equals the
                index immediately before it.

        Returns:
            One ``(text, mean_confidence)`` tuple per batch row.
        """
        skipped = self.get_ignored_tokens()
        decoded = []
        for row_no in range(len(text_index)):
            row = text_index[row_no]
            keep = np.ones(len(row), dtype=bool)
            if is_remove_duplicate:
                keep[1:] = row[1:] != row[:-1]
            for token in skipped:
                keep &= row != token

            text = ''.join(self.character[text_id] for text_id in row[keep])

            if text_prob is None:
                # Without scores every position counts as confidence 1.
                scores = [1] * len(keep)
            else:
                scores = text_prob[row_no][keep]
            if len(scores) == 0:
                # Nothing survived the mask: report confidence 0.
                scores = [0]

            decoded.append((text, np.mean(scores).tolist()))
        return decoded

    def get_ignored_tokens(self):
        """Return the indices skipped by ``decode``."""
        return [0]  # for ctc blank
 | |
| 
 | |
| 
 | |
class CTCLabelDecode(BaseRecLabelDecode):
    """Greedy CTC decoder: arg-max per step, collapse repeats, drop blank."""

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        """Forward the dictionary options to the base class.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model output and, optionally, the ground-truth label.

        Args:
            preds: Score array reduced over axis 2, a ``paddle.Tensor``
                holding one, or a tuple/list whose last element is either.
            label: Optional batch of label indices.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is given.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        best_idx = preds.argmax(axis=2)
        best_prob = preds.max(axis=2)
        decoded = self.decode(best_idx, best_prob, is_remove_duplicate=True)
        if label is not None:
            return decoded, self.decode(label)
        return decoded

    def add_special_char(self, dict_character):
        """Put the ``'blank'`` token at index 0 of the vocabulary."""
        return ['blank'] + dict_character
 | |
| 
 | |
| 
 | |
class DistillationCTCLabelDecode(CTCLabelDecode):
    """CTC decoding applied to each named sub-model of a distillation output."""

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 model_name=["student"],
                 key=None,
                 multi_head=False,
                 **kwargs):
        """Configure which sub-model outputs are decoded.

        Args:
            character_dict_path: Dictionary file passed to the base class.
            use_space_char: Passed to the base class.
            model_name: Name, or list of names, used to index ``preds``.
            key: Optional key selecting an entry inside each sub-model output.
            multi_head: When true and the selected output is a dict, its
                ``'ctc'`` entry is decoded.
        """
        super().__init__(character_dict_path, use_space_char)
        if not isinstance(model_name, list):
            model_name = [model_name]
        # Copy so the instance never aliases the shared default list (or a
        # caller-owned list that may be mutated later).
        self.model_name = list(model_name)

        self.key = key
        self.multi_head = multi_head

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode every configured sub-model output.

        Returns:
            Dict mapping each model name to the result of
            ``CTCLabelDecode.__call__`` for that model's prediction.
        """
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            if self.multi_head and isinstance(pred, dict):
                pred = pred['ctc']
            output[name] = super().__call__(pred, label=label, *args, **kwargs)
        return output
 | |
| 
 | |
| 
 | |
class NRTRLabelDecode(BaseRecLabelDecode):
    """Decoder for NRTR output; indices 0-3 are blank, <unk>, <s>, </s>."""

    def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
        """Forward the dictionary options to the base class.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions and, optionally, labels.

        Args:
            preds: Either a length-2 container ``(ids, probs)`` or a score
                array / ``paddle.Tensor`` reduced over axis 2.
            label: Optional label batch; its first column is dropped before
                decoding.

        Returns:
            Decoded predictions, or ``(predictions, labels)`` when ``label``
            is given.
        """
        # NOTE(review): a single score tensor with batch size 2 also has
        # len(preds) == 2 and would take this branch - confirm caller types.
        if len(preds) == 2:
            preds_idx = preds[0]
            preds_prob = preds[1]
            if isinstance(preds_idx, paddle.Tensor):
                preds_idx = preds_idx.numpy()
            if isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
            if preds_idx[0][0] == 2:
                # Leading '<s>' (index 2): drop the first column.
                preds_idx = preds_idx[:, 1:]
                preds_prob = preds_prob[:, 1:]
        else:
            if isinstance(preds, paddle.Tensor):
                preds = preds.numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        return text, self.decode(label[:, 1:])

    def add_special_char(self, dict_character):
        """Prepend the four special tokens to the vocabulary."""
        return ['blank', '<unk>', '<s>', '</s>'] + dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into lower-cased ``(text, confidence)`` tuples.

        Decoding of a row stops at the first index 3 (``'</s>'``). Indices
        outside the vocabulary are skipped. ``is_remove_duplicate`` is
        accepted for signature parity and not used.

        Returns:
            One ``(text, mean_confidence)`` tuple per row; confidence is 0
            when no character was decoded.
        """
        result_list = []
        for batch_idx in range(len(text_index)):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                if row[idx] == 3:  # end
                    break
                try:
                    char = self.character[int(row[idx])]
                except IndexError:
                    # Out-of-vocabulary index: skip instead of failing.
                    continue
                char_list.append(char)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if not conf_list:
                # np.mean([]) would be NaN with a RuntimeWarning; use 0 like
                # the base decoder does for empty results.
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text.lower(), np.mean(conf_list).tolist()))
        return result_list
 | |
| 
 | |
| 
 | |
class AttnLabelDecode(BaseRecLabelDecode):
    """Decoder for attention heads; vocabulary is ``sos + chars + eos``."""

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        """Forward the dictionary options to the base class.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        """Wrap the vocabulary with the begin and end tokens."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + dict_character + [self.end_str]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, confidence)`` tuples.

        A row is decoded up to its first end token; begin tokens are
        skipped. Previously the end token was skipped as an ignored token
        before the end check ran, so decoding never stopped at it.

        Args:
            text_index: Batch of index rows.
            text_prob: Optional per-position scores.
            is_remove_duplicate: Skip a position equal to the previous one.

        Returns:
            One ``(text, mean_confidence)`` tuple per row; confidence is 0
            when no character was decoded.
        """
        result_list = []
        beg_idx, end_idx = self.get_ignored_tokens()
        beg_idx = int(beg_idx)
        end_idx = int(end_idx)
        for batch_idx in range(len(text_index)):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                token = int(row[idx])
                # The end check must come first, otherwise it is unreachable.
                if token == end_idx:
                    break
                if token == beg_idx:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[token])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if not conf_list:
                # Avoid NaN and a RuntimeWarning from np.mean([]).
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Arg-max decode ``preds`` (axis 2) and optionally decode ``label``.

        Returns:
            Decoded predictions, or ``(predictions, labels)`` when ``label``
            is given.
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        """Return ``[begin_index, end_index]``."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the begin (``"beg"``) or end (``"end"``) token.

        Raises:
            AssertionError: For any other selector.
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        # Raised explicitly so the check also holds when asserts are stripped.
        raise AssertionError("unsupport type %s in get_beg_end_flag_idx" %
                             beg_or_end)
 | |
| 
 | |
| 
 | |
class SEEDLabelDecode(BaseRecLabelDecode):
    """Decoder for SEED; vocabulary is ``chars + eos + padding + unknown``."""

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        """Forward the dictionary options to the base class.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        """Append the end, padding and unknown tokens to the vocabulary."""
        self.padding_str = "padding"
        self.end_str = "eos"
        self.unknown = "unknown"
        return dict_character + [self.end_str, self.padding_str, self.unknown]

    def get_ignored_tokens(self):
        """Return a one-element list holding the end-token index."""
        return [self.get_beg_end_flag_idx("eos")]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the ``"sos"`` or ``"eos"`` token.

        NOTE(review): ``add_special_char`` adds no begin token, so the
        ``"sos"`` lookup presumably raises ``KeyError`` unless the
        dictionary file itself contains it - confirm.

        Raises:
            AssertionError: For any other selector.
        """
        if beg_or_end == "sos":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "eos":
            return np.array(self.dict[self.end_str])
        # Raised explicitly so the check also holds when asserts are stripped.
        raise AssertionError("unsupport type %s in get_beg_end_flag_idx" %
                             beg_or_end)

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, confidence)`` tuples.

        A row is decoded up to its first end token.

        Args:
            text_index: Batch of index rows.
            text_prob: Optional per-position scores.
            is_remove_duplicate: Skip a position equal to the previous one.

        Returns:
            One ``(text, mean_confidence)`` tuple per row; confidence is 0
            when no character was decoded.
        """
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        end_idx = int(end_idx)
        for batch_idx in range(len(text_index)):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                token = int(row[idx])
                if token == end_idx:
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[token])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if not conf_list:
                # Avoid NaN and a RuntimeWarning from np.mean([]).
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds["rec_pred"]`` and optionally ``label``.

        When ``preds`` has ``"rec_pred_scores"``, ``"rec_pred"`` is taken as
        already-selected indices; otherwise it is reduced over axis 2.

        Returns:
            Decoded predictions, or ``(predictions, labels)`` when ``label``
            is given.
        """
        rec_pred = preds["rec_pred"]
        # Convert once and keep using the converted value; the conversion
        # used to be discarded by re-reading preds["rec_pred"].
        if isinstance(rec_pred, paddle.Tensor):
            rec_pred = rec_pred.numpy()
        if "rec_pred_scores" in preds:
            preds_idx = rec_pred
            preds_prob = preds["rec_pred_scores"]
            if isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
        else:
            preds_idx = rec_pred.argmax(axis=2)
            preds_prob = rec_pred.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label
 | |
| 
 | |
| 
 | |
class SRNLabelDecode(BaseRecLabelDecode):
    """Decoder for SRN; vocabulary is ``chars + sos + eos``."""

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        """Set up the vocabulary and the fixed text length.

        Args:
            character_dict_path: Dictionary file passed to the base class.
            use_space_char: Passed to the base class.
            **kwargs: ``max_text_length`` (default 25) is read; the rest is
                ignored.
        """
        super().__init__(character_dict_path, use_space_char)
        self.max_text_length = kwargs.get('max_text_length', 25)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds['predict']`` and optionally ``label``.

        The prediction is reshaped to ``[-1, vocabulary size]``, reduced
        with arg-max, then reshaped to ``[-1, max_text_length]`` rows.

        Returns:
            Decoded predictions, or ``(predictions, labels)`` when ``label``
            is given.
        """
        pred = preds['predict']
        # Dictionary entries plus the two special tokens.
        char_num = len(self.character_str) + 2
        if isinstance(pred, paddle.Tensor):
            pred = pred.numpy()
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
        preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])

        # Decode once; the prediction used to be decoded a second time with
        # identical arguments when no label was given.
        text = self.decode(preds_idx, preds_prob)

        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, confidence)`` tuples.

        Begin and end tokens are skipped wherever they occur.

        Args:
            text_index: Batch of index rows.
            text_prob: Optional per-position scores.
            is_remove_duplicate: Skip a position equal to the previous one.

        Returns:
            One ``(text, mean_confidence)`` tuple per row; confidence is 0
            when no character was decoded.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()

        for batch_idx in range(len(text_index)):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                if row[idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[int(row[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            if not conf_list:
                # Avoid NaN and a RuntimeWarning from np.mean([]).
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def add_special_char(self, dict_character):
        """Append the begin and end tokens to the vocabulary."""
        return dict_character + [self.beg_str, self.end_str]

    def get_ignored_tokens(self):
        """Return ``[begin_index, end_index]``."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the begin (``"beg"``) or end (``"end"``) token.

        Raises:
            AssertionError: For any other selector.
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        # Raised explicitly so the check also holds when asserts are stripped.
        raise AssertionError("unsupport type %s in get_beg_end_flag_idx" %
                             beg_or_end)
 | |
| 
 | |
| 
 | |
class TableLabelDecode(object):
    """Decode table-structure predictions into HTML tokens and cell boxes.

    Two vocabularies are kept, one for characters and one for structure
    elements; each is wrapped with ``sos`` / ``eos`` tokens.
    """

    def __init__(self, character_dict_path, **kwargs):
        """Load both vocabularies and build their lookup tables.

        Args:
            character_dict_path: Path of the combined dictionary file read
                by ``load_char_elem_dict``.
            **kwargs: Accepted and ignored.
        """
        list_character, list_elem = self.load_char_elem_dict(
            character_dict_path)
        list_character = self.add_special_char(list_character)
        list_elem = self.add_special_char(list_elem)
        self.dict_idx_character = dict(enumerate(list_character))
        self.dict_character = {
            char: i
            for i, char in enumerate(list_character)
        }
        self.dict_idx_elem = dict(enumerate(list_elem))
        self.dict_elem = {elem: i for i, elem in enumerate(list_elem)}

    def load_char_elem_dict(self, character_dict_path):
        """Read the character and element lists from one dictionary file.

        The first line holds ``<character count>\\t<element count>``; that
        many character lines follow, then that many element lines.

        Returns:
            ``(list_character, list_elem)``.
        """
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()

        def clean(raw):
            return raw.decode('utf-8').strip("\n").strip("\r\n")

        substr = clean(lines[0]).split("\t")
        character_num = int(substr[0])
        elem_num = int(substr[1])
        elem_start = 1 + character_num
        list_character = [clean(lines[no]) for no in range(1, elem_start)]
        list_elem = [
            clean(lines[no]) for no in range(elem_start, elem_start + elem_num)
        ]
        return list_character, list_elem

    def add_special_char(self, list_character):
        """Return the list wrapped with the begin and end tokens."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return [self.beg_str] + list_character + [self.end_str]

    def __call__(self, preds):
        """Decode structure tokens and collect a location per table cell.

        Args:
            preds: Mapping with ``'structure_probs'`` (reduced over axis 2)
                and ``'loc_preds'`` (indexed as ``[batch, position]``);
                both may be ``paddle.Tensor`` or arrays.

        Returns:
            Dict with the joined HTML string per sample, the cell locations,
            the per-token scores, the per-token element indices and the raw
            token lists.
        """
        structure_probs = preds['structure_probs']
        loc_preds = preds['loc_preds']
        if isinstance(structure_probs, paddle.Tensor):
            structure_probs = structure_probs.numpy()
        if isinstance(loc_preds, paddle.Tensor):
            loc_preds = loc_preds.numpy()
        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)
        (structure_str, structure_pos, result_score_list,
         result_elem_idx_list) = self.decode(structure_idx, structure_probs,
                                             'elem')
        res_html_code_list = []
        res_loc_list = []
        for bno in range(len(structure_str)):
            # Only cell-opening tokens carry a location prediction.
            res_loc = [
                loc_preds[bno, structure_pos[bno][sno]]
                for sno, text in enumerate(structure_str[bno])
                if text in ['<td>', '<td']
            ]
            res_html_code_list.append(''.join(structure_str[bno]))
            res_loc_list.append(np.array(res_loc))
        return {
            'res_html_code': res_html_code_list,
            'res_loc': res_loc_list,
            'res_score_list': result_score_list,
            'res_elem_idx_list': result_elem_idx_list,
            'structure_str_list': structure_str
        }

    def decode(self, text_index, structure_probs, char_or_elem):
        """Convert index rows into token lists with positions and scores.

        Args:
            text_index: Batch of index rows.
            structure_probs: Scores indexed as ``[batch, position]``.
            char_or_elem: ``"char"`` selects the character vocabulary; any
                other value selects the element vocabulary.

        Returns:
            Four per-sample lists: tokens, their positions, their scores and
            their indices. Decoding of a row stops at the first end token
            found after position 0; begin/end tokens are never emitted.
        """
        if char_or_elem == "char":
            current_dict = self.dict_idx_character
            vocab_kind = "char"
        else:
            current_dict = self.dict_idx_elem
            vocab_kind = "elem"
        # Resolve the special tokens for both vocabularies; they used to be
        # bound only in the element branch, so "char" raised NameError.
        ignored_tokens = self.get_ignored_tokens(vocab_kind)
        end_idx = ignored_tokens[1]

        result_list = []
        result_pos_list = []
        result_score_list = []
        result_elem_idx_list = []
        for batch_idx in range(len(text_index)):
            char_list = []
            elem_pos_list = []
            elem_idx_list = []
            score_list = []
            for idx in range(len(text_index[batch_idx])):
                tmp_elem_idx = int(text_index[batch_idx][idx])
                if idx > 0 and tmp_elem_idx == end_idx:
                    break
                if tmp_elem_idx in ignored_tokens:
                    continue

                char_list.append(current_dict[tmp_elem_idx])
                elem_pos_list.append(idx)
                score_list.append(structure_probs[batch_idx, idx])
                elem_idx_list.append(tmp_elem_idx)
            result_list.append(char_list)
            result_pos_list.append(elem_pos_list)
            result_score_list.append(score_list)
            result_elem_idx_list.append(elem_idx_list)
        return (result_list, result_pos_list, result_score_list,
                result_elem_idx_list)

    def get_ignored_tokens(self, char_or_elem):
        """Return ``[begin_index, end_index]`` for the chosen vocabulary."""
        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
        """Return the begin or end token index of a vocabulary.

        Args:
            beg_or_end: ``"beg"`` or ``"end"``.
            char_or_elem: ``"char"`` or ``"elem"``.

        Raises:
            AssertionError: For an unknown selector. Raised explicitly so the
                check also holds when asserts are stripped with ``-O``.
        """
        if char_or_elem == "char":
            table = self.dict_character
        elif char_or_elem == "elem":
            table = self.dict_elem
        else:
            raise AssertionError("Unsupport type %s in char_or_elem" %
                                 char_or_elem)

        if beg_or_end == "beg":
            return table[self.beg_str]
        if beg_or_end == "end":
            return table[self.end_str]
        raise AssertionError("Unsupport type %s in get_beg_end_flag_idx of %s"
                             % (beg_or_end, char_or_elem))
 | |
| 
 | |
| 
 | |
class SARLabelDecode(BaseRecLabelDecode):
    """Decoder for SAR; vocabulary is ``chars + <UKN> + <BOS/EOS> + <PAD>``."""

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        """Set up the vocabulary and the symbol-removal option.

        Args:
            character_dict_path: Dictionary file passed to the base class.
            use_space_char: Passed to the base class.
            **kwargs: ``rm_symbol`` (default ``False``) enables lower-casing
                and regex filtering of decoded text; the rest is ignored.
        """
        super().__init__(character_dict_path, use_space_char)

        self.rm_symbol = kwargs.get('rm_symbol', False)

    def add_special_char(self, dict_character):
        """Append the special tokens and record their indices.

        Start and end share the single ``<BOS/EOS>`` entry.
        """
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index rows into ``(text, confidence)`` tuples.

        Padding is skipped. A row stops at the first ``<BOS/EOS>`` token,
        except at position 0 when no scores are given, where that token is
        skipped instead.

        Args:
            text_index: Batch of index rows.
            text_prob: Optional per-position scores.
            is_remove_duplicate: Skip a position equal to the previous one.

        Returns:
            One ``(text, mean_confidence)`` tuple per row; confidence is 0
            when no character was decoded.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        end_idx = int(self.end_idx)
        # Compile once per call instead of once per batch item.
        # NOTE(review): the inner '^' characters are literal inside the
        # class, so '^' itself survives filtering - confirm this is intended.
        comp = (re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
                if self.rm_symbol else None)

        for batch_idx in range(len(text_index)):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                if row[idx] in ignored_tokens:
                    continue
                if int(row[idx]) == end_idx:
                    if text_prob is None and idx == 0:
                        continue
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and row[idx - 1] == row[idx]:
                        continue
                char_list.append(self.character[int(row[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            if comp is not None:
                text = comp.sub('', text.lower())
            if not conf_list:
                # Avoid NaN and a RuntimeWarning from np.mean([]).
                conf_list = [0]
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Arg-max decode ``preds`` (axis 2) and optionally decode ``label``.

        Returns:
            Decoded predictions, or ``(predictions, labels)`` when ``label``
            is given.
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        """Return a one-element list holding the padding index."""
        return [self.padding_idx]
 | |
| 
 | |
| 
 | |
class DistillationSARLabelDecode(SARLabelDecode):
    """SAR decoding applied to each named sub-model of a distillation output."""

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 model_name=["student"],
                 key=None,
                 multi_head=False,
                 **kwargs):
        """Configure which sub-model outputs are decoded.

        Args:
            character_dict_path: Dictionary file passed to the base class.
            use_space_char: Passed to the base class.
            model_name: Name, or list of names, used to index ``preds``.
            key: Optional key selecting an entry inside each sub-model output.
            multi_head: When true and the selected output is a dict, its
                ``'sar'`` entry is decoded.
            **kwargs: Accepted but not forwarded to the base class.
        """
        # NOTE(review): kwargs are not forwarded, so an 'rm_symbol' option
        # has no effect here - confirm whether that is intended.
        super().__init__(character_dict_path, use_space_char)
        if not isinstance(model_name, list):
            model_name = [model_name]
        # Copy so the instance never aliases the shared default list (or a
        # caller-owned list that may be mutated later).
        self.model_name = list(model_name)

        self.key = key
        self.multi_head = multi_head

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode every configured sub-model output.

        Returns:
            Dict mapping each model name to the result of
            ``SARLabelDecode.__call__`` for that model's prediction.
        """
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            if self.multi_head and isinstance(pred, dict):
                pred = pred['sar']
            output[name] = super().__call__(pred, label=label, *args, **kwargs)
        return output
 | |
| 
 | |
| 
 | |
class PRENLabelDecode(BaseRecLabelDecode):
    """Decoder for PREN; indices 0-2 are <PAD>, <EOS>, <UNK>."""

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        """Forward the dictionary options to the base class.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        """Prepend the special tokens and record their indices."""
        padding_str = '<PAD>'  # 0
        end_str = '<EOS>'  # 1
        unknown_str = '<UNK>'  # 2

        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2

        return dict_character

    def decode(self, text_index, text_prob=None):
        """Convert index rows into ``(text, confidence)`` tuples.

        A row stops at the first end token; padding and unknown tokens are
        skipped.

        Returns:
            One ``(text, mean_confidence)`` tuple per row; an empty text is
            reported with confidence 1.
        """
        result_list = []
        skipped = [self.padding_idx, self.unknown_idx]

        for batch_idx in range(len(text_index)):
            row = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(row)):
                if row[idx] == self.end_idx:
                    break
                if row[idx] in skipped:
                    continue
                char_list.append(self.character[int(row[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = ''.join(char_list)
            if len(text) > 0:
                result_list.append((text, np.mean(conf_list).tolist()))
            else:
                # here confidence of empty recog result is 1
                result_list.append(('', 1))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Arg-max decode ``preds`` (axis 2) and optionally decode ``label``.

        Args:
            preds: A numpy array, or an object exposing ``.numpy()``.
            label: Optional batch of label indices.

        Returns:
            Decoded predictions, or ``(predictions, labels)`` when ``label``
            is given.
        """
        # Arrays are used as-is; the unconditional .numpy() call used to
        # fail for them, unlike the other decoders in this file.
        if not isinstance(preds, np.ndarray):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
 | |
| 
 | |
| 
 | |
class SVTRLabelDecode(BaseRecLabelDecode):
    """CTC-style decoder keeping the best of every three consecutive results."""

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        """Forward the dictionary options to the base class.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions in groups of three and keep the most confident.

        Args:
            preds: Score array reduced over its last axis, a
                ``paddle.Tensor`` holding one, or a tuple whose last element
                is either.
            label: Optional batch of label indices.

        Returns:
            One ``(text, confidence)`` per group of three decoded rows, or
            ``(results, labels)`` when ``label`` is given.
        """
        if isinstance(preds, tuple):
            preds = preds[-1]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        best_idx = preds.argmax(axis=-1)
        best_prob = preds.max(axis=-1)

        decoded = self.decode(best_idx, best_prob, is_remove_duplicate=True)
        picked = []
        for start in range(0, len(decoded), 3):
            # Explicit indexing: an incomplete trailing group raises.
            first = decoded[start]
            second = decoded[start + 1]
            third = decoded[start + 2]

            candidates = [first[0], second[0], third[0]]
            scores = [first[1], second[1], third[1]]
            winner = scores.index(max(scores))
            picked.append((candidates[winner], scores[winner]))
        if label is not None:
            return picked, self.decode(label)
        return picked

    def add_special_char(self, dict_character):
        """Put the ``'blank'`` token at index 0 of the vocabulary."""
        return ['blank'] + dict_character
