107 lines
3.5 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:52
# @Author : zhoujun
import copy
import PIL
import numpy as np
import paddle
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
from paddle.vision import transforms
def get_dataset(data_path, module_name, transform, dataset_args):
    """
    Instantiate a dataset class looked up by name on the package-local
    ``dataset`` module.

    :param data_path: forwarded to the dataset constructor as ``data_path``
    :param module_name: name of the attribute to fetch from the ``dataset`` module
    :param transform: forwarded to the dataset constructor as ``transform``
    :param dataset_args: mapping of additional keyword arguments for the constructor
    :return: whatever calling the looked-up attribute returns
    """
    # Imported inside the function, so the package-local module is only
    # resolved when a dataset is actually requested.
    from . import dataset

    dataset_factory = getattr(dataset, module_name)
    return dataset_factory(
        transform=transform,
        data_path=data_path,
        **dataset_args)
def get_transforms(transforms_config):
    """
    Build a ``transforms.Compose`` pipeline from a sequence of transform specs.

    :param transforms_config: iterable of mappings; each one has a ``'type'`` key
        naming an attribute of the imported ``transforms`` module and, optionally,
        an ``'args'`` key holding the keyword arguments used to instantiate it
    :return: ``transforms.Compose`` wrapping the instantiated transforms, in the
        order they appear in ``transforms_config``
    """
    pipeline = []
    for spec in transforms_config:
        # A spec without 'args' is instantiated with no arguments.
        kwargs = spec['args'] if 'args' in spec else {}
        factory = getattr(transforms, spec['type'])
        pipeline.append(factory(**kwargs))
    return transforms.Compose(pipeline)
class ICDARCollectFN:
    """
    Collate callable that merges a batch of sample dicts into a single dict.

    Every key maps to the list of its values in batch order; keys for which at
    least one value is an ``np.ndarray``, ``paddle.Tensor`` or ``PIL.Image.Image``
    have that list replaced by ``paddle.stack(values, 0)``.
    """

    def __init__(self, *args, **kwargs):
        # Any arguments are accepted and ignored; the collator keeps no state.
        pass

    def __call__(self, batch):
        """
        :param batch: iterable of dict samples
        :return: dict of per-key lists, with the array-like keys stacked
        """
        collated = {}
        stack_keys = []
        for sample in batch:
            for key, value in sample.items():
                bucket = collated.setdefault(key, [])
                # Remember, once and in first-seen order, each key that needs stacking.
                if isinstance(value, (np.ndarray, paddle.Tensor, PIL.Image.Image)) \
                        and key not in stack_keys:
                    stack_keys.append(key)
                bucket.append(value)
        for key in stack_keys:
            collated[key] = paddle.stack(collated[key], 0)
        return collated
def get_dataloader(module_config, distributed=False):
    """
    Build a ``DataLoader`` from a dataset/loader configuration.

    The configuration is deep-copied before use, so the caller's
    ``module_config`` is never modified.

    :param module_config: mapping with a ``'dataset'`` entry (``'type'`` plus
        ``'args'``; ``'args'`` must contain ``'data_path'`` and may contain
        ``'transforms'``) and a ``'loader'`` entry (must contain ``'batch_size'``
        and ``'shuffle'``, may contain ``'collate_fn'``; every remaining key is
        forwarded to ``DataLoader``). ``None`` is accepted.
    :param distributed: when true, batches are drawn with
        ``DistributedBatchSampler``; otherwise with ``BatchSampler``
    :return: the constructed ``DataLoader``, or ``None`` when ``module_config``
        is ``None``, ``data_path`` is ``None``, or ``data_path`` contains no
        non-``None`` entry
    :raises KeyError: if a required configuration key is missing
    """
    if module_config is None:
        return None
    config = copy.deepcopy(module_config)
    dataset_args = config['dataset']['args']
    if 'transforms' in dataset_args:
        img_transforms = get_transforms(dataset_args.pop('transforms'))
    else:
        img_transforms = None

    # Create the dataset.
    dataset_name = config['dataset']['type']
    data_path = dataset_args.pop('data_path')
    # Identity check: `== None` is an equality test and misbehaves for
    # array-like values, whose comparison is elementwise.
    if data_path is None:
        return None
    data_path = [x for x in data_path if x is not None]
    if not data_path:
        return None

    loader_args = config['loader']
    collate_name = loader_args.get('collate_fn')
    if not collate_name:
        # Missing, None or empty: DataLoader receives collate_fn=None.
        loader_args['collate_fn'] = None
    else:
        # SECURITY NOTE(review): eval() executes arbitrary Python taken from the
        # configuration. Only load configurations from trusted sources. It is
        # kept because existing configs may use any expression that resolves in
        # this module (e.g. a class name such as ICDARCollectFN).
        loader_args['collate_fn'] = eval(collate_name)()

    _dataset = get_dataset(
        data_path=data_path,
        module_name=dataset_name,
        transform=img_transforms,
        dataset_args=dataset_args)

    # Both samplers are built with the same keyword arguments; only the class
    # depends on whether distributed mode is requested.
    sampler_cls = DistributedBatchSampler if distributed else BatchSampler
    batch_sampler = sampler_cls(
        dataset=_dataset,
        batch_size=loader_args.pop('batch_size'),
        shuffle=loader_args.pop('shuffle'))

    # batch_size and shuffle were popped above, so only the remaining loader
    # options are forwarded here.
    return DataLoader(
        dataset=_dataset, batch_sampler=batch_sampler, **loader_args)