| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #    http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | from paddle.io import Dataset | 
					
						
							|  |  |  | import lmdb | 
					
						
							|  |  |  | import cv2 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from .imaug import transform, create_operators | 
					
						
							| 
									
										
										
										
											2020-11-05 15:13:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-29 15:03:41 +08:00
										 |  |  | class LMDBDataSet(Dataset): | 
					
						
							| 
									
										
										
										
											2021-01-24 12:32:22 +00:00
										 |  |  |     def __init__(self, config, mode, logger, seed=None): | 
					
						
							| 
									
										
										
										
											2021-01-29 15:03:41 +08:00
										 |  |  |         super(LMDBDataSet, self).__init__() | 
					
						
							| 
									
										
										
										
											2020-11-05 15:13:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |         global_config = config['Global'] | 
					
						
							|  |  |  |         dataset_config = config[mode]['dataset'] | 
					
						
							|  |  |  |         loader_config = config[mode]['loader'] | 
					
						
							|  |  |  |         batch_size = loader_config['batch_size_per_card'] | 
					
						
							|  |  |  |         data_dir = dataset_config['data_dir'] | 
					
						
							|  |  |  |         self.do_shuffle = loader_config['shuffle'] | 
					
						
							| 
									
										
										
										
											2020-11-05 15:13:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |         self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir) | 
					
						
							|  |  |  |         logger.info("Initialize indexs of datasets:%s" % data_dir) | 
					
						
							|  |  |  |         self.data_idx_order_list = self.dataset_traversal() | 
					
						
							|  |  |  |         if self.do_shuffle: | 
					
						
							|  |  |  |             np.random.shuffle(self.data_idx_order_list) | 
					
						
							|  |  |  |         self.ops = create_operators(dataset_config['transforms'], global_config) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-01-12 09:54:07 +00:00
										 |  |  |         ratio_list = dataset_config.get("ratio_list", [1.0]) | 
					
						
							|  |  |  |         self.need_reset = True in [x < 1 for x in ratio_list] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |     def load_hierarchical_lmdb_dataset(self, data_dir): | 
					
						
							|  |  |  |         lmdb_sets = {} | 
					
						
							|  |  |  |         dataset_idx = 0 | 
					
						
							|  |  |  |         for dirpath, dirnames, filenames in os.walk(data_dir + '/'): | 
					
						
							|  |  |  |             if not dirnames: | 
					
						
							|  |  |  |                 env = lmdb.open( | 
					
						
							|  |  |  |                     dirpath, | 
					
						
							|  |  |  |                     max_readers=32, | 
					
						
							|  |  |  |                     readonly=True, | 
					
						
							|  |  |  |                     lock=False, | 
					
						
							|  |  |  |                     readahead=False, | 
					
						
							|  |  |  |                     meminit=False) | 
					
						
							|  |  |  |                 txn = env.begin(write=False) | 
					
						
							|  |  |  |                 num_samples = int(txn.get('num-samples'.encode())) | 
					
						
							|  |  |  |                 lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \ | 
					
						
							|  |  |  |                     "txn":txn, "num_samples":num_samples} | 
					
						
							|  |  |  |                 dataset_idx += 1 | 
					
						
							|  |  |  |         return lmdb_sets | 
					
						
							| 
									
										
										
										
											2020-11-05 15:13:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |     def dataset_traversal(self): | 
					
						
							|  |  |  |         lmdb_num = len(self.lmdb_sets) | 
					
						
							|  |  |  |         total_sample_num = 0 | 
					
						
							|  |  |  |         for lno in range(lmdb_num): | 
					
						
							|  |  |  |             total_sample_num += self.lmdb_sets[lno]['num_samples'] | 
					
						
							|  |  |  |         data_idx_order_list = np.zeros((total_sample_num, 2)) | 
					
						
							|  |  |  |         beg_idx = 0 | 
					
						
							|  |  |  |         for lno in range(lmdb_num): | 
					
						
							|  |  |  |             tmp_sample_num = self.lmdb_sets[lno]['num_samples'] | 
					
						
							|  |  |  |             end_idx = beg_idx + tmp_sample_num | 
					
						
							|  |  |  |             data_idx_order_list[beg_idx:end_idx, 0] = lno | 
					
						
							|  |  |  |             data_idx_order_list[beg_idx:end_idx, 1] \ | 
					
						
							|  |  |  |                 = list(range(tmp_sample_num)) | 
					
						
							|  |  |  |             data_idx_order_list[beg_idx:end_idx, 1] += 1 | 
					
						
							|  |  |  |             beg_idx = beg_idx + tmp_sample_num | 
					
						
							|  |  |  |         return data_idx_order_list | 
					
						
							| 
									
										
										
										
											2020-11-05 15:13:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |     def get_img_data(self, value): | 
					
						
							|  |  |  |         """get_img_data""" | 
					
						
							|  |  |  |         if not value: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         imgdata = np.frombuffer(value, dtype='uint8') | 
					
						
							|  |  |  |         if imgdata is None: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         imgori = cv2.imdecode(imgdata, 1) | 
					
						
							|  |  |  |         if imgori is None: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         return imgori | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_lmdb_sample_info(self, txn, index): | 
					
						
							|  |  |  |         label_key = 'label-%09d'.encode() % index | 
					
						
							|  |  |  |         label = txn.get(label_key) | 
					
						
							|  |  |  |         if label is None: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         label = label.decode('utf-8') | 
					
						
							|  |  |  |         img_key = 'image-%09d'.encode() % index | 
					
						
							|  |  |  |         imgbuf = txn.get(img_key) | 
					
						
							|  |  |  |         return imgbuf, label | 
					
						
							| 
									
										
										
										
											2020-11-05 15:13:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |     def __getitem__(self, idx): | 
					
						
							|  |  |  |         lmdb_idx, file_idx = self.data_idx_order_list[idx] | 
					
						
							|  |  |  |         lmdb_idx = int(lmdb_idx) | 
					
						
							|  |  |  |         file_idx = int(file_idx) | 
					
						
							| 
									
										
										
										
											2020-11-05 15:13:36 +08:00
										 |  |  |         sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'], | 
					
						
							|  |  |  |                                                 file_idx) | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |         if sample_info is None: | 
					
						
							| 
									
										
										
										
											2020-11-05 15:13:36 +08:00
										 |  |  |             return self.__getitem__(np.random.randint(self.__len__())) | 
					
						
							| 
									
										
										
										
											2020-11-04 20:43:27 +08:00
										 |  |  |         img, label = sample_info | 
					
						
							|  |  |  |         data = {'image': img, 'label': label} | 
					
						
							|  |  |  |         outs = transform(data, self.ops) | 
					
						
							|  |  |  |         if outs is None: | 
					
						
							|  |  |  |             return self.__getitem__(np.random.randint(self.__len__())) | 
					
						
							|  |  |  |         return outs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __len__(self): | 
					
						
							|  |  |  |         return self.data_idx_order_list.shape[0] |