| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-12 13:49:24 +08:00
										 |  |  | import os | 
					
						
							|  |  |  | import sys | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-12 12:56:44 +08:00
										 |  |  | __dir__ = os.path.dirname(os.path.abspath(__file__)) | 
					
						
							| 
									
										
										
										
											2020-06-12 13:49:24 +08:00
										 |  |  | sys.path.append(__dir__) | 
					
						
							| 
									
										
										
										
											2020-08-12 12:56:44 +08:00
										 |  |  | sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | from ppocr.data import build_dataloader | 
					
						
							| 
									
										
										
										
											2020-11-09 13:28:15 +08:00
										 |  |  | from ppocr.modeling.architectures import build_model | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | from ppocr.postprocess import build_post_process | 
					
						
							|  |  |  | from ppocr.metrics import build_metric | 
					
						
							| 
									
										
										
										
											2021-11-12 11:06:36 +08:00
										 |  |  | from ppocr.utils.save_load import load_model | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | import tools.program as program | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  | def main(): | 
					
						
							|  |  |  |     global_config = config['Global'] | 
					
						
							|  |  |  |     # build dataloader | 
					
						
							| 
									
										
										
										
											2020-11-09 13:28:15 +08:00
										 |  |  |     valid_dataloader = build_dataloader(config, 'Eval', device, logger) | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     # build post process | 
					
						
							|  |  |  |     post_process_class = build_post_process(config['PostProcess'], | 
					
						
							|  |  |  |                                             global_config) | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     # build model | 
					
						
							|  |  |  |     # for rec algorithm | 
					
						
							|  |  |  |     if hasattr(post_process_class, 'character'): | 
					
						
							| 
									
										
										
										
											2021-06-17 13:29:49 +08:00
										 |  |  |         char_num = len(getattr(post_process_class, 'character')) | 
					
						
							|  |  |  |         if config['Architecture']["algorithm"] in ["Distillation", | 
					
						
							|  |  |  |                                                    ]:  # distillation model | 
					
						
							|  |  |  |             for key in config['Architecture']["Models"]: | 
					
						
							|  |  |  |                 config['Architecture']["Models"][key]["Head"][ | 
					
						
							|  |  |  |                     'out_channels'] = char_num | 
					
						
							|  |  |  |         else:  # base rec model | 
					
						
							|  |  |  |             config['Architecture']["Head"]['out_channels'] = char_num | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     model = build_model(config['Architecture']) | 
					
						
							| 
									
										
										
										
											2021-10-11 02:37:45 +00:00
										 |  |  |     extra_input = config['Architecture'][ | 
					
						
							|  |  |  |         'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] | 
					
						
							| 
									
										
										
										
											2021-07-09 14:29:39 +08:00
										 |  |  |     if "model_type" in config['Architecture'].keys(): | 
					
						
							|  |  |  |         model_type = config['Architecture']['model_type'] | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         model_type = None | 
					
						
							| 
									
										
										
										
											2021-07-08 14:32:44 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |     best_model_dict = load_model( | 
					
						
							|  |  |  |         config, model, model_type=config['Architecture']["model_type"]) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     if len(best_model_dict): | 
					
						
							|  |  |  |         logger.info('metric in ckpt ***************') | 
					
						
							|  |  |  |         for k, v in best_model_dict.items(): | 
					
						
							|  |  |  |             logger.info('{}:{}'.format(k, v)) | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     # build metric | 
					
						
							|  |  |  |     eval_class = build_metric(config['Metric']) | 
					
						
							|  |  |  |     # start eval | 
					
						
							| 
									
										
										
										
											2021-03-22 12:54:17 +08:00
										 |  |  |     metric = program.eval(model, valid_dataloader, post_process_class, | 
					
						
							| 
									
										
										
										
											2021-09-28 16:25:43 +08:00
										 |  |  |                           eval_class, model_type, extra_input) | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |     logger.info('metric eval ***************') | 
					
						
							| 
									
										
										
										
											2021-03-22 12:54:17 +08:00
										 |  |  |     for k, v in metric.items(): | 
					
						
							| 
									
										
										
										
											2020-10-13 17:13:33 +08:00
										 |  |  |         logger.info('{}:{}'.format(k, v)) | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == '__main__': | 
					
						
							| 
									
										
										
										
											2020-11-09 13:28:15 +08:00
										 |  |  |     config, device, logger, vdl_writer = program.preprocess() | 
					
						
							| 
									
										
										
										
											2020-05-10 16:26:57 +08:00
										 |  |  |     main() |