| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #    http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							|  |  |  | from __future__ import unicode_literals | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import copy | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | __all__ = ['build_post_process'] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-07 01:54:03 +00:00
										 |  |  | from .db_postprocess import DBPostProcess, DistillationDBPostProcess | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | from .east_postprocess import EASTPostProcess | 
					
						
							|  |  |  | from .sast_postprocess import SASTPostProcess | 
					
						
							| 
									
										
										
										
											2022-01-27 17:36:19 +08:00
										 |  |  | from .fce_postprocess import FCEPostProcess | 
					
						
							| 
									
										
										
										
											2022-02-28 21:48:00 +08:00
										 |  |  | from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ | 
					
						
							|  |  |  |     DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ | 
					
						
							|  |  |  |     SEEDLabelDecode, PRENLabelDecode | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | from .cls_postprocess import ClsPostProcess | 
					
						
							|  |  |  | from .pg_postprocess import PGPostProcess | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess | 
					
						
							|  |  |  | from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess | 
					
						
							| 
									
										
										
										
											2021-09-27 19:43:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-21 08:10:45 -06:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-02 08:31:57 +00:00
										 |  |  | def build_post_process(config, global_config=None): | 
					
						
							| 
									
										
										
										
											2020-11-21 08:10:45 -06:00
										 |  |  |     support_dict = [ | 
					
						
							| 
									
										
										
										
											2022-01-27 17:36:19 +08:00
										 |  |  |         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess', | 
					
						
							|  |  |  |         'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', | 
					
						
							|  |  |  |         'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', | 
					
						
							| 
									
										
										
										
											2021-09-27 14:58:10 +08:00
										 |  |  |         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |         'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', | 
					
						
							| 
									
										
										
										
											2022-02-28 21:48:00 +08:00
										 |  |  |         'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode' | 
					
						
							| 
									
										
										
										
											2020-11-21 08:10:45 -06:00
										 |  |  |     ] | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-02 16:10:43 +08:00
										 |  |  |     if config['name'] == 'PSEPostProcess': | 
					
						
							|  |  |  |         from .pse_postprocess import PSEPostProcess | 
					
						
							|  |  |  |         support_dict.append('PSEPostProcess') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     config = copy.deepcopy(config) | 
					
						
							|  |  |  |     module_name = config.pop('name') | 
					
						
							| 
									
										
										
										
											2021-10-09 15:40:25 +08:00
										 |  |  |     if module_name == "None": | 
					
						
							|  |  |  |         return | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     if global_config is not None: | 
					
						
							|  |  |  |         config.update(global_config) | 
					
						
							|  |  |  |     assert module_name in support_dict, Exception( | 
					
						
							|  |  |  |         'post process only support {}'.format(support_dict)) | 
					
						
							|  |  |  |     module_class = eval(module_name)(**config) | 
					
						
							|  |  |  |     return module_class |