mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-29 16:59:21 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			102 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			102 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| from __future__ import absolute_import
 | |
| from __future__ import division
 | |
| from __future__ import print_function
 | |
| from __future__ import unicode_literals
 | |
| 
 | |
| import os
 | |
| import sys
 | |
| import numpy as np
 | |
| import paddle
 | |
| import signal
 | |
| import random
 | |
| 
 | |
| __dir__ = os.path.dirname(os.path.abspath(__file__))
 | |
| sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
 | |
| 
 | |
| import copy
 | |
| from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
 | |
| import paddle.distributed as dist
 | |
| 
 | |
| from ppocr.data.imaug import transform, create_operators
 | |
| from ppocr.data.simple_dataset import SimpleDataSet
 | |
| from ppocr.data.lmdb_dataset import LMDBDataSet
 | |
| from ppocr.data.pgnet_dataset import PGDataSet
 | |
| from ppocr.data.pubtab_dataset import PubTabDataSet
 | |
| 
 | |
| __all__ = ['build_dataloader', 'transform', 'create_operators']
 | |
| 
 | |
| 
 | |
| def term_mp(sig_num, frame):
 | |
|     """ kill all child processes
 | |
|     """
 | |
|     pid = os.getpid()
 | |
|     pgid = os.getpgid(os.getpid())
 | |
|     print("main proc {} exit, kill process group " "{}".format(pid, pgid))
 | |
|     os.killpg(pgid, signal.SIGKILL)
 | |
| 
 | |
| 
 | |
| def build_dataloader(config, mode, device, logger, seed=None):
 | |
|     config = copy.deepcopy(config)
 | |
| 
 | |
|     support_dict = [
 | |
|         'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
 | |
|     ]
 | |
|     module_name = config[mode]['dataset']['name']
 | |
|     assert module_name in support_dict, Exception(
 | |
|         'DataSet only support {}'.format(support_dict))
 | |
|     assert mode in ['Train', 'Eval', 'Test'
 | |
|                     ], "Mode should be Train, Eval or Test."
 | |
| 
 | |
|     dataset = eval(module_name)(config, mode, logger, seed)
 | |
|     loader_config = config[mode]['loader']
 | |
|     batch_size = loader_config['batch_size_per_card']
 | |
|     drop_last = loader_config['drop_last']
 | |
|     shuffle = loader_config['shuffle']
 | |
|     num_workers = loader_config['num_workers']
 | |
|     if 'use_shared_memory' in loader_config.keys():
 | |
|         use_shared_memory = loader_config['use_shared_memory']
 | |
|     else:
 | |
|         use_shared_memory = True
 | |
|     if mode == "Train":
 | |
|         # Distribute data to multiple cards
 | |
|         batch_sampler = DistributedBatchSampler(
 | |
|             dataset=dataset,
 | |
|             batch_size=batch_size,
 | |
|             shuffle=shuffle,
 | |
|             drop_last=drop_last)
 | |
|     else:
 | |
|         # Distribute data to single card
 | |
|         batch_sampler = BatchSampler(
 | |
|             dataset=dataset,
 | |
|             batch_size=batch_size,
 | |
|             shuffle=shuffle,
 | |
|             drop_last=drop_last)
 | |
| 
 | |
|     data_loader = DataLoader(
 | |
|         dataset=dataset,
 | |
|         batch_sampler=batch_sampler,
 | |
|         places=device,
 | |
|         num_workers=num_workers,
 | |
|         return_list=True,
 | |
|         use_shared_memory=use_shared_memory)
 | |
| 
 | |
|     # support exit using ctrl+c
 | |
|     signal.signal(signal.SIGINT, term_mp)
 | |
|     signal.signal(signal.SIGTERM, term_mp)
 | |
| 
 | |
|     return data_loader
 | 
