| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #    http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							| 
									
										
										
										
											2022-02-28 21:48:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | import numpy as np | 
					
						
							|  |  |  | import paddle | 
					
						
							|  |  |  | from paddle.nn import functional as F | 
					
						
							| 
									
										
										
										
											2021-08-24 03:49:26 +00:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class BaseRecLabelDecode(object): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |     def __init__(self, character_dict_path=None, use_space_char=False): | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  |         self.beg_str = "sos" | 
					
						
							|  |  |  |         self.end_str = "eos" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |         self.character_str = [] | 
					
						
							|  |  |  |         if character_dict_path is None: | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" | 
					
						
							|  |  |  |             dict_character = list(self.character_str) | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |             with open(character_dict_path, "rb") as fin: | 
					
						
							|  |  |  |                 lines = fin.readlines() | 
					
						
							|  |  |  |                 for line in lines: | 
					
						
							|  |  |  |                     line = line.decode('utf-8').strip("\n").strip("\r\n") | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |                     self.character_str.append(line) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |             if use_space_char: | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |                 self.character_str.append(" ") | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |             dict_character = list(self.character_str) | 
					
						
							| 
									
										
										
										
											2021-01-26 15:24:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         dict_character = self.add_special_char(dict_character) | 
					
						
							|  |  |  |         self.dict = {} | 
					
						
							|  |  |  |         for i, char in enumerate(dict_character): | 
					
						
							|  |  |  |             self.dict[char] = i | 
					
						
							|  |  |  |         self.character = dict_character | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							|  |  |  |         return dict_character | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-19 21:03:13 +08:00
										 |  |  |     def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         """ convert text-index into text-label. """ | 
					
						
							|  |  |  |         result_list = [] | 
					
						
							|  |  |  |         ignored_tokens = self.get_ignored_tokens() | 
					
						
							|  |  |  |         batch_size = len(text_index) | 
					
						
							|  |  |  |         for batch_idx in range(batch_size): | 
					
						
							| 
									
										
										
										
											2022-03-24 11:02:24 +08:00
										 |  |  |             selection = np.ones(len(text_index[batch_idx]), dtype=bool) | 
					
						
							|  |  |  |             if is_remove_duplicate: | 
					
						
							|  |  |  |                 selection[1:] = text_index[batch_idx][1:] != text_index[ | 
					
						
							|  |  |  |                     batch_idx][:-1] | 
					
						
							|  |  |  |             for ignored_token in ignored_tokens: | 
					
						
							|  |  |  |                 selection &= text_index[batch_idx] != ignored_token | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             char_list = [ | 
					
						
							|  |  |  |                 self.character[text_id] | 
					
						
							|  |  |  |                 for text_id in text_index[batch_idx][selection] | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |             if text_prob is not None: | 
					
						
							|  |  |  |                 conf_list = text_prob[batch_idx][selection] | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 conf_list = [1] * len(selection) | 
					
						
							|  |  |  |             if len(conf_list) == 0: | 
					
						
							|  |  |  |                 conf_list = [0] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |             text = ''.join(char_list) | 
					
						
							| 
									
										
										
										
											2022-04-22 13:24:45 +08:00
										 |  |  |             result_list.append((text, np.mean(conf_list).tolist())) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         return result_list | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_ignored_tokens(self): | 
					
						
							|  |  |  |         return [0]  # for ctc blank | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CTCLabelDecode(BaseRecLabelDecode): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |     def __init__(self, character_dict_path=None, use_space_char=False, | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(CTCLabelDecode, self).__init__(character_dict_path, | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |                                              use_space_char) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							| 
									
										
										
										
											2022-04-07 08:30:17 +00:00
										 |  |  |         if isinstance(preds, tuple) or isinstance(preds, list): | 
					
						
							| 
									
										
										
										
											2021-09-29 10:30:09 +08:00
										 |  |  |             preds = preds[-1] | 
					
						
							| 
									
										
										
										
											2020-11-09 18:19:42 +08:00
										 |  |  |         if isinstance(preds, paddle.Tensor): | 
					
						
							|  |  |  |             preds = preds.numpy() | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         preds_idx = preds.argmax(axis=2) | 
					
						
							|  |  |  |         preds_prob = preds.max(axis=2) | 
					
						
							| 
									
										
										
										
											2021-01-20 18:33:42 +08:00
										 |  |  |         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         if label is None: | 
					
						
							|  |  |  |             return text | 
					
						
							| 
									
										
										
										
											2021-01-19 21:03:13 +08:00
										 |  |  |         label = self.decode(label) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         return text, label | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							|  |  |  |         dict_character = ['blank'] + dict_character | 
					
						
							|  |  |  |         return dict_character | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | class DistillationCTCLabelDecode(CTCLabelDecode): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Convert  | 
					
						
							|  |  |  |     Convert between text-label and text-index | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  character_dict_path=None, | 
					
						
							|  |  |  |                  use_space_char=False, | 
					
						
							| 
									
										
										
										
											2021-06-03 13:31:25 +00:00
										 |  |  |                  model_name=["student"], | 
					
						
							| 
									
										
										
										
											2021-06-03 05:57:31 +00:00
										 |  |  |                  key=None, | 
					
						
							| 
									
										
										
										
											2022-04-26 16:19:31 +08:00
										 |  |  |                  multi_head=False, | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  |                  **kwargs): | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |         super(DistillationCTCLabelDecode, self).__init__(character_dict_path, | 
					
						
							|  |  |  |                                                          use_space_char) | 
					
						
							| 
									
										
										
										
											2021-06-03 13:31:25 +00:00
										 |  |  |         if not isinstance(model_name, list): | 
					
						
							|  |  |  |             model_name = [model_name] | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  |         self.model_name = model_name | 
					
						
							| 
									
										
										
										
											2021-06-03 13:31:25 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-03 05:57:31 +00:00
										 |  |  |         self.key = key | 
					
						
							| 
									
										
										
										
											2022-04-26 16:19:31 +08:00
										 |  |  |         self.multi_head = multi_head | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							| 
									
										
										
										
											2021-06-03 13:31:25 +00:00
										 |  |  |         output = dict() | 
					
						
							|  |  |  |         for name in self.model_name: | 
					
						
							|  |  |  |             pred = preds[name] | 
					
						
							|  |  |  |             if self.key is not None: | 
					
						
							|  |  |  |                 pred = pred[self.key] | 
					
						
							| 
									
										
										
										
											2022-04-26 16:19:31 +08:00
										 |  |  |             if self.multi_head and isinstance(pred, dict): | 
					
						
							|  |  |  |                 pred = pred['ctc'] | 
					
						
							| 
									
										
										
										
											2021-06-03 13:31:25 +00:00
										 |  |  |             output[name] = super().__call__(pred, label=label, *args, **kwargs) | 
					
						
							|  |  |  |         return output | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  | class NRTRLabelDecode(BaseRecLabelDecode): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |     def __init__(self, character_dict_path=None, use_space_char=True, **kwargs): | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |         super(NRTRLabelDecode, self).__init__(character_dict_path, | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |                                               use_space_char) | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-13 13:10:10 +00:00
										 |  |  |         if len(preds) == 2: | 
					
						
							|  |  |  |             preds_id = preds[0] | 
					
						
							|  |  |  |             preds_prob = preds[1] | 
					
						
							|  |  |  |             if isinstance(preds_id, paddle.Tensor): | 
					
						
							|  |  |  |                 preds_id = preds_id.numpy() | 
					
						
							|  |  |  |             if isinstance(preds_prob, paddle.Tensor): | 
					
						
							|  |  |  |                 preds_prob = preds_prob.numpy() | 
					
						
							|  |  |  |             if preds_id[0][0] == 2: | 
					
						
							|  |  |  |                 preds_idx = preds_id[:, 1:] | 
					
						
							|  |  |  |                 preds_prob = preds_prob[:, 1:] | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 preds_idx = preds_id | 
					
						
							|  |  |  |             text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |             if label is None: | 
					
						
							|  |  |  |                 return text | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |             label = self.decode(label[:, 1:]) | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |         else: | 
					
						
							|  |  |  |             if isinstance(preds, paddle.Tensor): | 
					
						
							|  |  |  |                 preds = preds.numpy() | 
					
						
							|  |  |  |             preds_idx = preds.argmax(axis=2) | 
					
						
							|  |  |  |             preds_prob = preds.max(axis=2) | 
					
						
							|  |  |  |             text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) | 
					
						
							|  |  |  |             if label is None: | 
					
						
							|  |  |  |                 return text | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |             label = self.decode(label[:, 1:]) | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |         return text, label | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |         dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |         return dict_character | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |     def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | 
					
						
							|  |  |  |         """ convert text-index into text-label. """ | 
					
						
							|  |  |  |         result_list = [] | 
					
						
							|  |  |  |         batch_size = len(text_index) | 
					
						
							|  |  |  |         for batch_idx in range(batch_size): | 
					
						
							|  |  |  |             char_list = [] | 
					
						
							|  |  |  |             conf_list = [] | 
					
						
							|  |  |  |             for idx in range(len(text_index[batch_idx])): | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |                 if text_index[batch_idx][idx] == 3:  # end | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |                     break | 
					
						
							|  |  |  |                 try: | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |                     char_list.append(self.character[int(text_index[batch_idx][ | 
					
						
							|  |  |  |                         idx])]) | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |                 except: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  |                 if text_prob is not None: | 
					
						
							|  |  |  |                     conf_list.append(text_prob[batch_idx][idx]) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     conf_list.append(1) | 
					
						
							|  |  |  |             text = ''.join(char_list) | 
					
						
							| 
									
										
										
										
											2022-04-22 13:24:45 +08:00
										 |  |  |             result_list.append((text.lower(), np.mean(conf_list).tolist())) | 
					
						
							| 
									
										
										
										
											2021-08-16 11:33:15 +00:00
										 |  |  |         return result_list | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | class AttnLabelDecode(BaseRecLabelDecode): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |     def __init__(self, character_dict_path=None, use_space_char=False, | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(AttnLabelDecode, self).__init__(character_dict_path, | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |                                               use_space_char) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							| 
									
										
										
										
											2021-01-29 03:15:03 +00:00
										 |  |  |         self.beg_str = "sos" | 
					
						
							|  |  |  |         self.end_str = "eos" | 
					
						
							|  |  |  |         dict_character = dict_character | 
					
						
							|  |  |  |         dict_character = [self.beg_str] + dict_character + [self.end_str] | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         return dict_character | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-01 06:27:56 +00:00
										 |  |  |     def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | 
					
						
							|  |  |  |         """ convert text-index into text-label. """ | 
					
						
							|  |  |  |         result_list = [] | 
					
						
							|  |  |  |         ignored_tokens = self.get_ignored_tokens() | 
					
						
							|  |  |  |         [beg_idx, end_idx] = self.get_ignored_tokens() | 
					
						
							|  |  |  |         batch_size = len(text_index) | 
					
						
							|  |  |  |         for batch_idx in range(batch_size): | 
					
						
							|  |  |  |             char_list = [] | 
					
						
							|  |  |  |             conf_list = [] | 
					
						
							|  |  |  |             for idx in range(len(text_index[batch_idx])): | 
					
						
							|  |  |  |                 if text_index[batch_idx][idx] in ignored_tokens: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  |                 if int(text_index[batch_idx][idx]) == int(end_idx): | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  |                 if is_remove_duplicate: | 
					
						
							|  |  |  |                     # only for predict | 
					
						
							|  |  |  |                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ | 
					
						
							|  |  |  |                             batch_idx][idx]: | 
					
						
							|  |  |  |                         continue | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |                 char_list.append(self.character[int(text_index[batch_idx][ | 
					
						
							|  |  |  |                     idx])]) | 
					
						
							| 
									
										
										
										
											2021-02-01 06:27:56 +00:00
										 |  |  |                 if text_prob is not None: | 
					
						
							|  |  |  |                     conf_list.append(text_prob[batch_idx][idx]) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     conf_list.append(1) | 
					
						
							|  |  |  |             text = ''.join(char_list) | 
					
						
							| 
									
										
										
										
											2022-04-22 13:24:45 +08:00
										 |  |  |             result_list.append((text, np.mean(conf_list).tolist())) | 
					
						
							| 
									
										
										
										
											2021-02-01 06:27:56 +00:00
										 |  |  |         return result_list | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-29 03:15:03 +00:00
										 |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         text = self.decode(text) | 
					
						
							| 
									
										
										
										
											2021-01-29 03:15:03 +00:00
										 |  |  |         if label is None: | 
					
						
							|  |  |  |             return text | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             label = self.decode(label, is_remove_duplicate=False) | 
					
						
							|  |  |  |             return text, label | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if isinstance(preds, paddle.Tensor): | 
					
						
							|  |  |  |             preds = preds.numpy() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         preds_idx = preds.argmax(axis=2) | 
					
						
							|  |  |  |         preds_prob = preds.max(axis=2) | 
					
						
							| 
									
										
										
										
											2021-02-01 06:27:56 +00:00
										 |  |  |         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) | 
					
						
							| 
									
										
										
										
											2021-01-29 03:15:03 +00:00
										 |  |  |         if label is None: | 
					
						
							|  |  |  |             return text | 
					
						
							| 
									
										
										
										
											2021-02-01 06:27:56 +00:00
										 |  |  |         label = self.decode(label, is_remove_duplicate=False) | 
					
						
							| 
									
										
										
										
											2021-01-29 03:15:03 +00:00
										 |  |  |         return text, label | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     def get_ignored_tokens(self): | 
					
						
							|  |  |  |         beg_idx = self.get_beg_end_flag_idx("beg") | 
					
						
							|  |  |  |         end_idx = self.get_beg_end_flag_idx("end") | 
					
						
							|  |  |  |         return [beg_idx, end_idx] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_beg_end_flag_idx(self, beg_or_end): | 
					
						
							|  |  |  |         if beg_or_end == "beg": | 
					
						
							|  |  |  |             idx = np.array(self.dict[self.beg_str]) | 
					
						
							|  |  |  |         elif beg_or_end == "end": | 
					
						
							|  |  |  |             idx = np.array(self.dict[self.end_str]) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             assert False, "unsupport type %s in get_beg_end_flag_idx" \ | 
					
						
							|  |  |  |                           % beg_or_end | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         return idx | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-30 06:32:54 +00:00
										 |  |  | class SEEDLabelDecode(BaseRecLabelDecode): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |     def __init__(self, character_dict_path=None, use_space_char=False, | 
					
						
							| 
									
										
										
										
											2021-08-30 06:32:54 +00:00
										 |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(SEEDLabelDecode, self).__init__(character_dict_path, | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |                                               use_space_char) | 
					
						
							| 
									
										
										
										
											2021-08-30 06:32:54 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							| 
									
										
										
										
											2021-12-17 21:42:53 +08:00
										 |  |  |         self.padding_str = "padding" | 
					
						
							| 
									
										
										
										
											2021-08-30 06:32:54 +00:00
										 |  |  |         self.end_str = "eos" | 
					
						
							| 
									
										
										
										
											2021-12-17 21:42:53 +08:00
										 |  |  |         self.unknown = "unknown" | 
					
						
							|  |  |  |         dict_character = dict_character + [ | 
					
						
							|  |  |  |             self.end_str, self.padding_str, self.unknown | 
					
						
							|  |  |  |         ] | 
					
						
							| 
									
										
										
										
											2021-08-30 06:32:54 +00:00
										 |  |  |         return dict_character | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_ignored_tokens(self): | 
					
						
							|  |  |  |         end_idx = self.get_beg_end_flag_idx("eos") | 
					
						
							|  |  |  |         return [end_idx] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_beg_end_flag_idx(self, beg_or_end): | 
					
						
							|  |  |  |         if beg_or_end == "sos": | 
					
						
							|  |  |  |             idx = np.array(self.dict[self.beg_str]) | 
					
						
							|  |  |  |         elif beg_or_end == "eos": | 
					
						
							|  |  |  |             idx = np.array(self.dict[self.end_str]) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end | 
					
						
							|  |  |  |         return idx | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | 
					
						
							|  |  |  |         """ convert text-index into text-label. """ | 
					
						
							|  |  |  |         result_list = [] | 
					
						
							|  |  |  |         [end_idx] = self.get_ignored_tokens() | 
					
						
							|  |  |  |         batch_size = len(text_index) | 
					
						
							|  |  |  |         for batch_idx in range(batch_size): | 
					
						
							|  |  |  |             char_list = [] | 
					
						
							|  |  |  |             conf_list = [] | 
					
						
							|  |  |  |             for idx in range(len(text_index[batch_idx])): | 
					
						
							|  |  |  |                 if int(text_index[batch_idx][idx]) == int(end_idx): | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  |                 if is_remove_duplicate: | 
					
						
							|  |  |  |                     # only for predict | 
					
						
							|  |  |  |                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ | 
					
						
							|  |  |  |                             batch_idx][idx]: | 
					
						
							|  |  |  |                         continue | 
					
						
							|  |  |  |                 char_list.append(self.character[int(text_index[batch_idx][ | 
					
						
							|  |  |  |                     idx])]) | 
					
						
							|  |  |  |                 if text_prob is not None: | 
					
						
							|  |  |  |                     conf_list.append(text_prob[batch_idx][idx]) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     conf_list.append(1) | 
					
						
							|  |  |  |             text = ''.join(char_list) | 
					
						
							| 
									
										
										
										
											2022-04-22 13:24:45 +08:00
										 |  |  |             result_list.append((text, np.mean(conf_list).tolist())) | 
					
						
							| 
									
										
										
										
											2021-08-30 06:32:54 +00:00
										 |  |  |         return result_list | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         text = self.decode(text) | 
					
						
							|  |  |  |         if label is None: | 
					
						
							|  |  |  |             return text | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             label = self.decode(label, is_remove_duplicate=False) | 
					
						
							|  |  |  |             return text, label | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         preds_idx = preds["rec_pred"] | 
					
						
							|  |  |  |         if isinstance(preds_idx, paddle.Tensor): | 
					
						
							|  |  |  |             preds_idx = preds_idx.numpy() | 
					
						
							|  |  |  |         if "rec_pred_scores" in preds: | 
					
						
							|  |  |  |             preds_idx = preds["rec_pred"] | 
					
						
							|  |  |  |             preds_prob = preds["rec_pred_scores"] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             preds_idx = preds["rec_pred"].argmax(axis=2) | 
					
						
							|  |  |  |             preds_prob = preds["rec_pred"].max(axis=2) | 
					
						
							|  |  |  |         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) | 
					
						
							|  |  |  |         if label is None: | 
					
						
							|  |  |  |             return text | 
					
						
							|  |  |  |         label = self.decode(label, is_remove_duplicate=False) | 
					
						
							|  |  |  |         return text, label | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  | class SRNLabelDecode(BaseRecLabelDecode): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |     def __init__(self, character_dict_path=None, use_space_char=False, | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(SRNLabelDecode, self).__init__(character_dict_path, | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |                                              use_space_char) | 
					
						
							| 
									
										
										
										
											2021-04-27 12:12:19 +08:00
										 |  |  |         self.max_text_length = kwargs.get('max_text_length', 25) | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							|  |  |  |         pred = preds['predict'] | 
					
						
							|  |  |  |         char_num = len(self.character_str) + 2 | 
					
						
							|  |  |  |         if isinstance(pred, paddle.Tensor): | 
					
						
							|  |  |  |             pred = pred.numpy() | 
					
						
							|  |  |  |         pred = np.reshape(pred, [-1, char_num]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         preds_idx = np.argmax(pred, axis=1) | 
					
						
							|  |  |  |         preds_prob = np.max(pred, axis=1) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-27 12:12:19 +08:00
										 |  |  |         preds_idx = np.reshape(preds_idx, [-1, self.max_text_length]) | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-27 12:12:19 +08:00
										 |  |  |         preds_prob = np.reshape(preds_prob, [-1, self.max_text_length]) | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-22 03:15:56 +00:00
										 |  |  |         text = self.decode(preds_idx, preds_prob) | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if label is None: | 
					
						
							| 
									
										
										
										
											2021-02-01 06:32:14 +00:00
										 |  |  |             text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  |             return text | 
					
						
							| 
									
										
										
										
											2021-01-22 03:15:56 +00:00
										 |  |  |         label = self.decode(label) | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  |         return text, label | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-01 06:32:14 +00:00
										 |  |  |     def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  |         """ convert text-index into text-label. """ | 
					
						
							|  |  |  |         result_list = [] | 
					
						
							|  |  |  |         ignored_tokens = self.get_ignored_tokens() | 
					
						
							|  |  |  |         batch_size = len(text_index) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for batch_idx in range(batch_size): | 
					
						
							|  |  |  |             char_list = [] | 
					
						
							|  |  |  |             conf_list = [] | 
					
						
							|  |  |  |             for idx in range(len(text_index[batch_idx])): | 
					
						
							|  |  |  |                 if text_index[batch_idx][idx] in ignored_tokens: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  |                 if is_remove_duplicate: | 
					
						
							|  |  |  |                     # only for predict | 
					
						
							|  |  |  |                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ | 
					
						
							|  |  |  |                             batch_idx][idx]: | 
					
						
							|  |  |  |                         continue | 
					
						
							|  |  |  |                 char_list.append(self.character[int(text_index[batch_idx][ | 
					
						
							|  |  |  |                     idx])]) | 
					
						
							|  |  |  |                 if text_prob is not None: | 
					
						
							|  |  |  |                     conf_list.append(text_prob[batch_idx][idx]) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     conf_list.append(1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             text = ''.join(char_list) | 
					
						
							| 
									
										
										
										
											2022-04-22 13:24:45 +08:00
										 |  |  |             result_list.append((text, np.mean(conf_list).tolist())) | 
					
						
							| 
									
										
										
										
											2020-12-30 16:15:49 +08:00
										 |  |  |         return result_list | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							|  |  |  |         dict_character = dict_character + [self.beg_str, self.end_str] | 
					
						
							|  |  |  |         return dict_character | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_ignored_tokens(self): | 
					
						
							|  |  |  |         beg_idx = self.get_beg_end_flag_idx("beg") | 
					
						
							|  |  |  |         end_idx = self.get_beg_end_flag_idx("end") | 
					
						
							|  |  |  |         return [beg_idx, end_idx] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_beg_end_flag_idx(self, beg_or_end): | 
					
						
							|  |  |  |         if beg_or_end == "beg": | 
					
						
							|  |  |  |             idx = np.array(self.dict[self.beg_str]) | 
					
						
							|  |  |  |         elif beg_or_end == "end": | 
					
						
							|  |  |  |             idx = np.array(self.dict[self.end_str]) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             assert False, "unsupport type %s in get_beg_end_flag_idx" \ | 
					
						
							|  |  |  |                           % beg_or_end | 
					
						
							|  |  |  |         return idx | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TableLabelDecode(object): | 
					
						
							|  |  |  |     """  """ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |     def __init__(self, character_dict_path, **kwargs): | 
					
						
							|  |  |  |         list_character, list_elem = self.load_char_elem_dict( | 
					
						
							|  |  |  |             character_dict_path) | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |         list_character = self.add_special_char(list_character) | 
					
						
							|  |  |  |         list_elem = self.add_special_char(list_elem) | 
					
						
							|  |  |  |         self.dict_character = {} | 
					
						
							|  |  |  |         self.dict_idx_character = {} | 
					
						
							|  |  |  |         for i, char in enumerate(list_character): | 
					
						
							|  |  |  |             self.dict_idx_character[i] = char | 
					
						
							|  |  |  |             self.dict_character[char] = i | 
					
						
							|  |  |  |         self.dict_elem = {} | 
					
						
							|  |  |  |         self.dict_idx_elem = {} | 
					
						
							|  |  |  |         for i, elem in enumerate(list_elem): | 
					
						
							|  |  |  |             self.dict_idx_elem[i] = elem | 
					
						
							|  |  |  |             self.dict_elem[elem] = i | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def load_char_elem_dict(self, character_dict_path): | 
					
						
							|  |  |  |         list_character = [] | 
					
						
							|  |  |  |         list_elem = [] | 
					
						
							|  |  |  |         with open(character_dict_path, "rb") as fin: | 
					
						
							|  |  |  |             lines = fin.readlines() | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |             substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split( | 
					
						
							|  |  |  |                 "\t") | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |             character_num = int(substr[0]) | 
					
						
							|  |  |  |             elem_num = int(substr[1]) | 
					
						
							|  |  |  |             for cno in range(1, 1 + character_num): | 
					
						
							| 
									
										
										
										
											2021-07-30 13:31:46 +08:00
										 |  |  |                 character = lines[cno].decode('utf-8').strip("\n").strip("\r\n") | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |                 list_character.append(character) | 
					
						
							|  |  |  |             for eno in range(1 + character_num, 1 + character_num + elem_num): | 
					
						
							| 
									
										
										
										
											2021-07-30 13:31:46 +08:00
										 |  |  |                 elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n") | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |                 list_elem.append(elem) | 
					
						
							|  |  |  |         return list_character, list_elem | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, list_character): | 
					
						
							|  |  |  |         self.beg_str = "sos" | 
					
						
							|  |  |  |         self.end_str = "eos" | 
					
						
							|  |  |  |         list_character = [self.beg_str] + list_character + [self.end_str] | 
					
						
							|  |  |  |         return list_character | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds): | 
					
						
							|  |  |  |         structure_probs = preds['structure_probs'] | 
					
						
							|  |  |  |         loc_preds = preds['loc_preds'] | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |         if isinstance(structure_probs, paddle.Tensor): | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |             structure_probs = structure_probs.numpy() | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |         if isinstance(loc_preds, paddle.Tensor): | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |             loc_preds = loc_preds.numpy() | 
					
						
							|  |  |  |         structure_idx = structure_probs.argmax(axis=2) | 
					
						
							|  |  |  |         structure_probs = structure_probs.max(axis=2) | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |         structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode( | 
					
						
							|  |  |  |             structure_idx, structure_probs, 'elem') | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  |         res_html_code_list = [] | 
					
						
							|  |  |  |         res_loc_list = [] | 
					
						
							|  |  |  |         batch_num = len(structure_str) | 
					
						
							|  |  |  |         for bno in range(batch_num): | 
					
						
							|  |  |  |             res_loc = [] | 
					
						
							|  |  |  |             for sno in range(len(structure_str[bno])): | 
					
						
							|  |  |  |                 text = structure_str[bno][sno] | 
					
						
							|  |  |  |                 if text in ['<td>', '<td']: | 
					
						
							|  |  |  |                     pos = structure_pos[bno][sno] | 
					
						
							|  |  |  |                     res_loc.append(loc_preds[bno, pos]) | 
					
						
							|  |  |  |             res_html_code = ''.join(structure_str[bno]) | 
					
						
							|  |  |  |             res_loc = np.array(res_loc) | 
					
						
							|  |  |  |             res_html_code_list.append(res_html_code) | 
					
						
							|  |  |  |             res_loc_list.append(res_loc) | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |         return { | 
					
						
							|  |  |  |             'res_html_code': res_html_code_list, | 
					
						
							|  |  |  |             'res_loc': res_loc_list, | 
					
						
							|  |  |  |             'res_score_list': result_score_list, | 
					
						
							|  |  |  |             'res_elem_idx_list': result_elem_idx_list, | 
					
						
							|  |  |  |             'structure_str_list': structure_str | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-06-10 14:24:59 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def decode(self, text_index, structure_probs, char_or_elem): | 
					
						
							|  |  |  |         """convert text-label into text-index.
 | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if char_or_elem == "char": | 
					
						
							|  |  |  |             current_dict = self.dict_idx_character | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             current_dict = self.dict_idx_elem | 
					
						
							|  |  |  |             ignored_tokens = self.get_ignored_tokens('elem') | 
					
						
							|  |  |  |             beg_idx, end_idx = ignored_tokens | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         result_list = [] | 
					
						
							|  |  |  |         result_pos_list = [] | 
					
						
							|  |  |  |         result_score_list = [] | 
					
						
							|  |  |  |         result_elem_idx_list = [] | 
					
						
							|  |  |  |         batch_size = len(text_index) | 
					
						
							|  |  |  |         for batch_idx in range(batch_size): | 
					
						
							|  |  |  |             char_list = [] | 
					
						
							|  |  |  |             elem_pos_list = [] | 
					
						
							|  |  |  |             elem_idx_list = [] | 
					
						
							|  |  |  |             score_list = [] | 
					
						
							|  |  |  |             for idx in range(len(text_index[batch_idx])): | 
					
						
							|  |  |  |                 tmp_elem_idx = int(text_index[batch_idx][idx]) | 
					
						
							|  |  |  |                 if idx > 0 and tmp_elem_idx == end_idx: | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  |                 if tmp_elem_idx in ignored_tokens: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 char_list.append(current_dict[tmp_elem_idx]) | 
					
						
							|  |  |  |                 elem_pos_list.append(idx) | 
					
						
							|  |  |  |                 score_list.append(structure_probs[batch_idx, idx]) | 
					
						
							|  |  |  |                 elem_idx_list.append(tmp_elem_idx) | 
					
						
							|  |  |  |             result_list.append(char_list) | 
					
						
							|  |  |  |             result_pos_list.append(elem_pos_list) | 
					
						
							|  |  |  |             result_score_list.append(score_list) | 
					
						
							|  |  |  |             result_elem_idx_list.append(elem_idx_list) | 
					
						
							|  |  |  |         return result_list, result_pos_list, result_score_list, result_elem_idx_list | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_ignored_tokens(self, char_or_elem): | 
					
						
							|  |  |  |         beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) | 
					
						
							|  |  |  |         end_idx = self.get_beg_end_flag_idx("end", char_or_elem) | 
					
						
							|  |  |  |         return [beg_idx, end_idx] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): | 
					
						
							|  |  |  |         if char_or_elem == "char": | 
					
						
							|  |  |  |             if beg_or_end == "beg": | 
					
						
							|  |  |  |                 idx = self.dict_character[self.beg_str] | 
					
						
							|  |  |  |             elif beg_or_end == "end": | 
					
						
							|  |  |  |                 idx = self.dict_character[self.end_str] | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ | 
					
						
							|  |  |  |                               % beg_or_end | 
					
						
							|  |  |  |         elif char_or_elem == "elem": | 
					
						
							|  |  |  |             if beg_or_end == "beg": | 
					
						
							|  |  |  |                 idx = self.dict_elem[self.beg_str] | 
					
						
							|  |  |  |             elif beg_or_end == "end": | 
					
						
							|  |  |  |                 idx = self.dict_elem[self.end_str] | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ | 
					
						
							|  |  |  |                               % beg_or_end | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             assert False, "Unsupport type %s in char_or_elem" \ | 
					
						
							|  |  |  |                           % char_or_elem | 
					
						
							|  |  |  |         return idx | 
					
						
							| 
									
										
										
										
											2021-08-24 03:49:26 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SARLabelDecode(BaseRecLabelDecode): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |     def __init__(self, character_dict_path=None, use_space_char=False, | 
					
						
							| 
									
										
										
										
											2021-08-24 03:49:26 +00:00
										 |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(SARLabelDecode, self).__init__(character_dict_path, | 
					
						
							| 
									
										
										
										
											2021-10-12 14:29:00 +08:00
										 |  |  |                                              use_space_char) | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-07 11:31:23 +00:00
										 |  |  |         self.rm_symbol = kwargs.get('rm_symbol', False) | 
					
						
							| 
									
										
										
										
											2021-08-24 03:49:26 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							|  |  |  |         beg_end_str = "<BOS/EOS>" | 
					
						
							|  |  |  |         unknown_str = "<UKN>" | 
					
						
							|  |  |  |         padding_str = "<PAD>" | 
					
						
							|  |  |  |         dict_character = dict_character + [unknown_str] | 
					
						
							|  |  |  |         self.unknown_idx = len(dict_character) - 1 | 
					
						
							|  |  |  |         dict_character = dict_character + [beg_end_str] | 
					
						
							|  |  |  |         self.start_idx = len(dict_character) - 1 | 
					
						
							|  |  |  |         self.end_idx = len(dict_character) - 1 | 
					
						
							|  |  |  |         dict_character = dict_character + [padding_str] | 
					
						
							|  |  |  |         self.padding_idx = len(dict_character) - 1 | 
					
						
							|  |  |  |         return dict_character | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | 
					
						
							|  |  |  |         """ convert text-index into text-label. """ | 
					
						
							|  |  |  |         result_list = [] | 
					
						
							|  |  |  |         ignored_tokens = self.get_ignored_tokens() | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-24 03:49:26 +00:00
										 |  |  |         batch_size = len(text_index) | 
					
						
							|  |  |  |         for batch_idx in range(batch_size): | 
					
						
							|  |  |  |             char_list = [] | 
					
						
							|  |  |  |             conf_list = [] | 
					
						
							|  |  |  |             for idx in range(len(text_index[batch_idx])): | 
					
						
							|  |  |  |                 if text_index[batch_idx][idx] in ignored_tokens: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  |                 if int(text_index[batch_idx][idx]) == int(self.end_idx): | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  |                     if text_prob is None and idx == 0: | 
					
						
							| 
									
										
										
										
											2021-08-24 03:49:26 +00:00
										 |  |  |                         continue | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         break | 
					
						
							|  |  |  |                 if is_remove_duplicate: | 
					
						
							|  |  |  |                     # only for predict | 
					
						
							|  |  |  |                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ | 
					
						
							|  |  |  |                             batch_idx][idx]: | 
					
						
							|  |  |  |                         continue | 
					
						
							|  |  |  |                 char_list.append(self.character[int(text_index[batch_idx][ | 
					
						
							|  |  |  |                     idx])]) | 
					
						
							|  |  |  |                 if text_prob is not None: | 
					
						
							|  |  |  |                     conf_list.append(text_prob[batch_idx][idx]) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     conf_list.append(1) | 
					
						
							|  |  |  |             text = ''.join(char_list) | 
					
						
							| 
									
										
										
										
											2021-09-02 07:18:13 +00:00
										 |  |  |             if self.rm_symbol: | 
					
						
							|  |  |  |                 comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') | 
					
						
							|  |  |  |                 text = text.lower() | 
					
						
							|  |  |  |                 text = comp.sub('', text) | 
					
						
							| 
									
										
										
										
											2022-04-22 13:24:45 +08:00
										 |  |  |             result_list.append((text, np.mean(conf_list).tolist())) | 
					
						
							| 
									
										
										
										
											2021-08-24 03:49:26 +00:00
										 |  |  |         return result_list | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							|  |  |  |         if isinstance(preds, paddle.Tensor): | 
					
						
							|  |  |  |             preds = preds.numpy() | 
					
						
							|  |  |  |         preds_idx = preds.argmax(axis=2) | 
					
						
							|  |  |  |         preds_prob = preds.max(axis=2) | 
					
						
							| 
									
										
										
										
											2021-09-07 06:09:59 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-24 03:49:26 +00:00
										 |  |  |         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if label is None: | 
					
						
							|  |  |  |             return text | 
					
						
							|  |  |  |         label = self.decode(label, is_remove_duplicate=False) | 
					
						
							|  |  |  |         return text, label | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_ignored_tokens(self): | 
					
						
							|  |  |  |         return [self.padding_idx] | 
					
						
							| 
									
										
										
										
											2022-02-28 21:48:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-26 16:19:31 +08:00
										 |  |  | class DistillationSARLabelDecode(SARLabelDecode): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Convert  | 
					
						
							|  |  |  |     Convert between text-label and text-index | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  character_dict_path=None, | 
					
						
							|  |  |  |                  use_space_char=False, | 
					
						
							|  |  |  |                  model_name=["student"], | 
					
						
							|  |  |  |                  key=None, | 
					
						
							|  |  |  |                  multi_head=False, | 
					
						
							|  |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(DistillationSARLabelDecode, self).__init__(character_dict_path, | 
					
						
							|  |  |  |                                                          use_space_char) | 
					
						
							|  |  |  |         if not isinstance(model_name, list): | 
					
						
							|  |  |  |             model_name = [model_name] | 
					
						
							|  |  |  |         self.model_name = model_name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.key = key | 
					
						
							|  |  |  |         self.multi_head = multi_head | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							|  |  |  |         output = dict() | 
					
						
							|  |  |  |         for name in self.model_name: | 
					
						
							|  |  |  |             pred = preds[name] | 
					
						
							|  |  |  |             if self.key is not None: | 
					
						
							|  |  |  |                 pred = pred[self.key] | 
					
						
							|  |  |  |             if self.multi_head and isinstance(pred, dict): | 
					
						
							|  |  |  |                 pred = pred['sar'] | 
					
						
							|  |  |  |             output[name] = super().__call__(pred, label=label, *args, **kwargs) | 
					
						
							|  |  |  |         return output | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-28 21:48:00 +08:00
										 |  |  | class PRENLabelDecode(BaseRecLabelDecode): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, character_dict_path=None, use_space_char=False, | 
					
						
							|  |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(PRENLabelDecode, self).__init__(character_dict_path, | 
					
						
							|  |  |  |                                               use_space_char) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							|  |  |  |         padding_str = '<PAD>'  # 0  | 
					
						
							|  |  |  |         end_str = '<EOS>'  # 1 | 
					
						
							|  |  |  |         unknown_str = '<UNK>'  # 2 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         dict_character = [padding_str, end_str, unknown_str] + dict_character | 
					
						
							|  |  |  |         self.padding_idx = 0 | 
					
						
							|  |  |  |         self.end_idx = 1 | 
					
						
							|  |  |  |         self.unknown_idx = 2 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return dict_character | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def decode(self, text_index, text_prob=None): | 
					
						
							|  |  |  |         """ convert text-index into text-label. """ | 
					
						
							|  |  |  |         result_list = [] | 
					
						
							|  |  |  |         batch_size = len(text_index) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for batch_idx in range(batch_size): | 
					
						
							|  |  |  |             char_list = [] | 
					
						
							|  |  |  |             conf_list = [] | 
					
						
							|  |  |  |             for idx in range(len(text_index[batch_idx])): | 
					
						
							|  |  |  |                 if text_index[batch_idx][idx] == self.end_idx: | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  |                 if text_index[batch_idx][idx] in \ | 
					
						
							|  |  |  |                     [self.padding_idx, self.unknown_idx]: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  |                 char_list.append(self.character[int(text_index[batch_idx][ | 
					
						
							|  |  |  |                     idx])]) | 
					
						
							|  |  |  |                 if text_prob is not None: | 
					
						
							|  |  |  |                     conf_list.append(text_prob[batch_idx][idx]) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     conf_list.append(1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             text = ''.join(char_list) | 
					
						
							|  |  |  |             if len(text) > 0: | 
					
						
							| 
									
										
										
										
											2022-04-22 13:24:45 +08:00
										 |  |  |                 result_list.append((text, np.mean(conf_list).tolist())) | 
					
						
							| 
									
										
										
										
											2022-02-28 21:48:00 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 # here confidence of empty recog result is 1 | 
					
						
							|  |  |  |                 result_list.append(('', 1)) | 
					
						
							|  |  |  |         return result_list | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							|  |  |  |         preds = preds.numpy() | 
					
						
							|  |  |  |         preds_idx = preds.argmax(axis=2) | 
					
						
							|  |  |  |         preds_prob = preds.max(axis=2) | 
					
						
							|  |  |  |         text = self.decode(preds_idx, preds_prob) | 
					
						
							|  |  |  |         if label is None: | 
					
						
							|  |  |  |             return text | 
					
						
							|  |  |  |         label = self.decode(label) | 
					
						
							|  |  |  |         return text, label | 
					
						
							| 
									
										
										
										
											2022-04-26 10:30:26 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SVTRLabelDecode(BaseRecLabelDecode): | 
					
						
							|  |  |  |     """ Convert between text-label and text-index """ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, character_dict_path=None, use_space_char=False, | 
					
						
							|  |  |  |                  **kwargs): | 
					
						
							|  |  |  |         super(SVTRLabelDecode, self).__init__(character_dict_path, | 
					
						
							|  |  |  |                                              use_space_char) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, label=None, *args, **kwargs): | 
					
						
							|  |  |  |         if isinstance(preds, tuple): | 
					
						
							|  |  |  |             preds = preds[-1] | 
					
						
							|  |  |  |         if isinstance(preds, paddle.Tensor): | 
					
						
							|  |  |  |             preds = preds.numpy() | 
					
						
							|  |  |  |         preds_idx = preds.argmax(axis=-1) | 
					
						
							|  |  |  |         preds_prob = preds.max(axis=-1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) | 
					
						
							|  |  |  |         return_text = [] | 
					
						
							|  |  |  |         for i in range(0, len(text), 3): | 
					
						
							|  |  |  |             text0 = text[i] | 
					
						
							|  |  |  |             text1 = text[i + 1] | 
					
						
							|  |  |  |             text2 = text[i + 2] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             text_pred = [text0[0], text1[0], text2[0]] | 
					
						
							|  |  |  |             text_prob = [text0[1], text1[1], text2[1]] | 
					
						
							|  |  |  |             id_max = text_prob.index(max(text_prob)) | 
					
						
							|  |  |  |             return_text.append((text_pred[id_max], text_prob[id_max])) | 
					
						
							|  |  |  |         if label is None: | 
					
						
							|  |  |  |             return return_text | 
					
						
							|  |  |  |         label = self.decode(label) | 
					
						
							|  |  |  |         return return_text, label | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_special_char(self, dict_character): | 
					
						
							|  |  |  |         dict_character = ['blank'] + dict_character | 
					
						
							|  |  |  |         return dict_character |