| 
									
										
										
										
											2021-01-20 19:06:39 +08:00
										 |  |  | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  | import yaml | 
					
						
							|  |  |  | from argparse import ArgumentParser, RawDescriptionHelpFormatter | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  | import os.path | 
					
						
							| 
									
										
										
										
											2021-01-20 12:08:57 +08:00
										 |  |  | import logging | 
					
						
							|  |  |  | logging.basicConfig(level=logging.INFO) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | support_list = { | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |     'it': 'italian', | 
					
						
							|  |  |  |     'xi': 'spanish', | 
					
						
							|  |  |  |     'pu': 'portuguese', | 
					
						
							|  |  |  |     'ru': 'russian', | 
					
						
							|  |  |  |     'ar': 'arabic', | 
					
						
							|  |  |  |     'ta': 'tamil', | 
					
						
							|  |  |  |     'ug': 'uyghur', | 
					
						
							|  |  |  |     'fa': 'persian', | 
					
						
							|  |  |  |     'ur': 'urdu', | 
					
						
							|  |  |  |     'rs': 'serbian latin', | 
					
						
							|  |  |  |     'oc': 'occitan', | 
					
						
							|  |  |  |     'rsc': 'serbian cyrillic', | 
					
						
							|  |  |  |     'bg': 'bulgarian', | 
					
						
							|  |  |  |     'uk': 'ukranian', | 
					
						
							|  |  |  |     'be': 'belarusian', | 
					
						
							|  |  |  |     'te': 'telugu', | 
					
						
							|  |  |  |     'ka': 'kannada', | 
					
						
							|  |  |  |     'chinese_cht': 'chinese tradition', | 
					
						
							|  |  |  |     'hi': 'hindi', | 
					
						
							|  |  |  |     'mr': 'marathi', | 
					
						
							|  |  |  |     'ne': 'nepali', | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | latin_lang = [ | 
					
						
							|  |  |  |     'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', | 
					
						
							|  |  |  |     'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl', | 
					
						
							|  |  |  |     'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv', | 
					
						
							|  |  |  |     'sw', 'tl', 'tr', 'uz', 'vi', 'latin' | 
					
						
							|  |  |  | ] | 
					
						
							|  |  |  | arabic_lang = ['ar', 'fa', 'ug', 'ur'] | 
					
						
							|  |  |  | cyrillic_lang = [ | 
					
						
							|  |  |  |     'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', | 
					
						
							|  |  |  |     'dar', 'inh', 'che', 'lbe', 'lez', 'tab', 'cyrillic' | 
					
						
							|  |  |  | ] | 
					
						
							|  |  |  | devanagari_lang = [ | 
					
						
							|  |  |  |     'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', | 
					
						
							|  |  |  |     'sa', 'bgc', 'devanagari' | 
					
						
							|  |  |  | ] | 
					
						
							|  |  |  | multi_lang = latin_lang + arabic_lang + cyrillic_lang + devanagari_lang | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | assert (os.path.isfile("./rec_multi_language_lite_train.yml") | 
					
						
							|  |  |  |         ), "Loss basic configuration file rec_multi_language_lite_train.yml.\
 | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  | You can download it from \ | 
					
						
							|  |  |  | https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/" | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | global_config = yaml.load( | 
					
						
							|  |  |  |     open("./rec_multi_language_lite_train.yml", 'rb'), Loader=yaml.Loader) | 
					
						
							| 
									
										
										
										
											2021-01-20 12:08:57 +08:00
										 |  |  | project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../")) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  | class ArgsParser(ArgumentParser): | 
					
						
							|  |  |  |     def __init__(self): | 
					
						
							|  |  |  |         super(ArgsParser, self).__init__( | 
					
						
							|  |  |  |             formatter_class=RawDescriptionHelpFormatter) | 
					
						
							|  |  |  |         self.add_argument( | 
					
						
							|  |  |  |             "-o", "--opt", nargs='+', help="set configuration options") | 
					
						
							|  |  |  |         self.add_argument( | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |             "-l", | 
					
						
							|  |  |  |             "--language", | 
					
						
							|  |  |  |             nargs='+', | 
					
						
							|  |  |  |             help="set language type, support {}".format(support_list)) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |         self.add_argument( | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |             "--train", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             help="you can use this command to change the train dataset default path" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |         self.add_argument( | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |             "--val", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             help="you can use this command to change the eval dataset default path" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |         self.add_argument( | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |             "--dict", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             help="you can use this command to change the dictionary default path" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  |         self.add_argument( | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |             "--data_dir", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             help="you can use this command to change the dataset default root path" | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def parse_args(self, argv=None): | 
					
						
							|  |  |  |         args = super(ArgsParser, self).parse_args(argv) | 
					
						
							|  |  |  |         args.opt = self._parse_opt(args.opt) | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  |         args.language = self._set_language(args.language) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |         return args | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _parse_opt(self, opts): | 
					
						
							|  |  |  |         config = {} | 
					
						
							|  |  |  |         if not opts: | 
					
						
							|  |  |  |             return config | 
					
						
							|  |  |  |         for s in opts: | 
					
						
							|  |  |  |             s = s.strip() | 
					
						
							|  |  |  |             k, v = s.split('=') | 
					
						
							|  |  |  |             config[k] = yaml.load(v, Loader=yaml.Loader) | 
					
						
							|  |  |  |         return config | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _set_language(self, type): | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |         lang = type[0] | 
					
						
							|  |  |  |         assert (type), "please use -l or --language to choose language type" | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |         assert( | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |                 lang in support_list.keys() or lang in multi_lang | 
					
						
							| 
									
										
										
										
											2021-01-20 12:08:57 +08:00
										 |  |  |                ),"the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, " \ | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |                  "please check your running command".format(multi_lang, type) | 
					
						
							|  |  |  |         if lang in latin_lang: | 
					
						
							|  |  |  |             lang = "latin" | 
					
						
							|  |  |  |         elif lang in arabic_lang: | 
					
						
							|  |  |  |             lang = "arabic" | 
					
						
							|  |  |  |         elif lang in cyrillic_lang: | 
					
						
							|  |  |  |             lang = "cyrillic" | 
					
						
							|  |  |  |         elif lang in devanagari_lang: | 
					
						
							|  |  |  |             lang = "devanagari" | 
					
						
							|  |  |  |         global_config['Global'][ | 
					
						
							|  |  |  |             'character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang) | 
					
						
							|  |  |  |         global_config['Global'][ | 
					
						
							|  |  |  |             'save_model_dir'] = './output/rec_{}_lite'.format(lang) | 
					
						
							|  |  |  |         global_config['Train']['dataset'][ | 
					
						
							|  |  |  |             'label_file_list'] = ["train_data/{}_train.txt".format(lang)] | 
					
						
							|  |  |  |         global_config['Eval']['dataset'][ | 
					
						
							|  |  |  |             'label_file_list'] = ["train_data/{}_val.txt".format(lang)] | 
					
						
							|  |  |  |         global_config['Global']['character_type'] = lang | 
					
						
							|  |  |  |         assert ( | 
					
						
							|  |  |  |             os.path.isfile( | 
					
						
							|  |  |  |                 os.path.join(project_path, global_config['Global'][ | 
					
						
							|  |  |  |                     'character_dict_path'])) | 
					
						
							|  |  |  |         ), "Loss default dictionary file {}_dict.txt.You can download it from \
 | 
					
						
							|  |  |  | https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format( | 
					
						
							|  |  |  |             lang) | 
					
						
							|  |  |  |         return lang | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  | def merge_config(config): | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Merge config into global config. | 
					
						
							|  |  |  |     Args: | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  |         config (dict): Config to be merged. | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |     Returns: global config | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  |     for key, value in config.items(): | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |         if "." not in key: | 
					
						
							|  |  |  |             if isinstance(value, dict) and key in global_config: | 
					
						
							|  |  |  |                 global_config[key].update(value) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 global_config[key] = value | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             sub_keys = key.split('.') | 
					
						
							|  |  |  |             assert ( | 
					
						
							|  |  |  |                 sub_keys[0] in global_config | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  |             ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format( | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |                 global_config.keys(), sub_keys[0]) | 
					
						
							|  |  |  |             cur = global_config[sub_keys[0]] | 
					
						
							|  |  |  |             for idx, sub_key in enumerate(sub_keys[1:]): | 
					
						
							|  |  |  |                 if idx == len(sub_keys) - 2: | 
					
						
							|  |  |  |                     cur[sub_key] = value | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     cur = cur[sub_key] | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-20 12:08:57 +08:00
										 |  |  | def loss_file(path): | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |     assert ( | 
					
						
							|  |  |  |         os.path.exists(path) | 
					
						
							|  |  |  |     ), "There is no such file:{},Please do not forget to put in the specified file".format( | 
					
						
							|  |  |  |         path) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  |     FLAGS = ArgsParser().parse_args() | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  |     merge_config(FLAGS.opt) | 
					
						
							| 
									
										
										
										
											2021-01-20 13:07:35 +08:00
										 |  |  |     save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language) | 
					
						
							|  |  |  |     if os.path.isfile(save_file_path): | 
					
						
							|  |  |  |         os.remove(save_file_path) | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |     if FLAGS.train: | 
					
						
							|  |  |  |         global_config['Train']['dataset']['label_file_list'] = [FLAGS.train] | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |         train_label_path = os.path.join(project_path, FLAGS.train) | 
					
						
							| 
									
										
										
										
											2021-01-20 12:08:57 +08:00
										 |  |  |         loss_file(train_label_path) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |     if FLAGS.val: | 
					
						
							|  |  |  |         global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val] | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |         eval_label_path = os.path.join(project_path, FLAGS.val) | 
					
						
							| 
									
										
										
										
											2021-03-03 11:09:20 +08:00
										 |  |  |         loss_file(eval_label_path) | 
					
						
							| 
									
										
										
										
											2021-01-19 15:52:04 +08:00
										 |  |  |     if FLAGS.dict: | 
					
						
							|  |  |  |         global_config['Global']['character_dict_path'] = FLAGS.dict | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |         dict_path = os.path.join(project_path, FLAGS.dict) | 
					
						
							| 
									
										
										
										
											2021-01-20 12:08:57 +08:00
										 |  |  |         loss_file(dict_path) | 
					
						
							|  |  |  |     if FLAGS.data_dir: | 
					
						
							|  |  |  |         global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir | 
					
						
							|  |  |  |         global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |         data_dir = os.path.join(project_path, FLAGS.data_dir) | 
					
						
							| 
									
										
										
										
											2021-01-20 12:08:57 +08:00
										 |  |  |         loss_file(data_dir) | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-19 23:46:35 +08:00
										 |  |  |     with open(save_file_path, 'w') as f: | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |         yaml.dump( | 
					
						
							|  |  |  |             dict(global_config), f, default_flow_style=False, sort_keys=False) | 
					
						
							| 
									
										
										
										
											2021-01-20 12:08:57 +08:00
										 |  |  |     logging.info("Project path is          :{}".format(project_path)) | 
					
						
							| 
									
										
										
										
											2021-04-13 17:54:10 +08:00
										 |  |  |     logging.info("Train list path set to   :{}".format(global_config['Train'][ | 
					
						
							|  |  |  |         'dataset']['label_file_list'][0])) | 
					
						
							|  |  |  |     logging.info("Eval list path set to    :{}".format(global_config['Eval'][ | 
					
						
							|  |  |  |         'dataset']['label_file_list'][0])) | 
					
						
							|  |  |  |     logging.info("Dataset root path set to :{}".format(global_config['Eval'][ | 
					
						
							|  |  |  |         'dataset']['data_dir'])) | 
					
						
							|  |  |  |     logging.info("Dict path set to         :{}".format(global_config['Global'][ | 
					
						
							|  |  |  |         'character_dict_path'])) | 
					
						
							|  |  |  |     logging.info("Config file set to       :configs/rec/multi_language/{}". | 
					
						
							|  |  |  |                  format(save_file_path)) |