官方文档
from torch.utils.data import DataLoader
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
参数含义:
关键是dataset如何构建
dataset有两种形式:一种是映射式(map-style),另一种是迭代式(iterable-style)
1、映射式(map-style)
map-style数据集实现了__getitem__()和__len__()协议,表示从索引/键(可能不是整数)到数据样本的映射。
example:用coco数据集举例
import PIL.Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from nets.nets_utility import *
import torchvision.transforms.functional as F
class COCODataset(Dataset):
    """Map-style dataset serving grayscale images from a flat directory.

    Every entry of ``input_dir`` except ``.ipynb_checkpoints`` is treated as
    an image file.  A sample is read as a single-channel image, divided by
    255, resized to ``_RESIZE_SHAPE``, optionally randomly cropped, optionally
    randomly flipped, and finally passed through ``transform`` (or a plain
    ``ToTensor`` when no transform was supplied).

    NOTE(review): ``os``, ``cv2``, ``np`` and ``random`` are not imported
    explicitly in this file; they are presumably re-exported by
    ``from nets.nets_utility import *`` -- confirm.
    """

    # Target size handed to cv2.resize for every image before any cropping.
    _RESIZE_SHAPE = (256, 256)

    # ---- dataset parameter initialisation ----
    def __init__(self, input_dir, crop_size=256, transform=None, need_crop=False, need_augment=False):
        """Index the image files found in ``input_dir``.

        Args:
            input_dir: Directory whose entries are the image files.
            crop_size: Side length of the square random crop taken when
                ``need_crop`` is true.
            transform: Callable applied to the PIL image of each sample;
                ``transforms.ToTensor()`` is used when this is ``None``.
            need_crop: Whether to take a random ``crop_size`` crop.
            need_augment: Whether to apply random augmentation (currently a
                random horizontal flip).
        """
        # Keep the names sorted so that _images_basename[i] always matches
        # _images_address[i]; Jupyter's checkpoint folder is not an image.
        self._images_basename = sorted(
            name for name in os.listdir(input_dir) if name != '.ipynb_checkpoints'
        )
        # Full path of every image, in the same order as the basenames.
        self._images_address = [os.path.join(input_dir, name) for name in self._images_basename]

        # Pre-processing configuration.
        self._crop_size = crop_size
        self._transform = transform
        self._origin_transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self._need_crop = need_crop
        self._need_augment = need_augment

    # ---- total number of samples in the dataset ----
    def __len__(self) -> int:
        """Return how many image paths were indexed."""
        return len(self._images_address)

    # ---- fetch a single sample by index ----
    def __getitem__(self, idx):
        """Load, pre-process and transform the image at position ``idx``.

        Returns:
            Whatever the configured transform returns for the processed image.

        Raises:
            OSError: If OpenCV cannot read the file at that position.
            ValueError: If cropping is enabled and ``crop_size`` is larger
                than the resized image.
        """
        path = self._images_address[idx]
        raw = cv2.imread(path, 0)  # flag 0 -> read as grayscale
        if raw is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cv2.imread could not read image: {path}")
        image = raw / 255.0  # scale pixel values by 1/255

        # Image pre-processing.
        image = cv2.resize(image, self._RESIZE_SHAPE)
        roi_image_np = self._random_crop(image) if self._need_crop else image
        roi_image_pil = self._rand_augment(roi_image_np)

        transform = self._transform if self._transform is not None else self._origin_transform
        return transform(roi_image_pil)

    def _rand_augment(self, image):
        """Convert ``image`` to a float32 PIL image and flip it if enabled."""
        image_pil = PIL.Image.fromarray(image.astype(np.float32))
        if self._need_augment:
            image_pil = self._rand_horizontal_flip(image_pil)
        return image_pil

    # ---- image augmentation helpers ----
    # Only _rand_horizontal_flip is wired into _rand_augment; the rotation and
    # vertical-flip helpers are defined but not called inside this class.

    def _rand_rotated(self, image_pil):
        """Rotate by an angle drawn uniformly from 0, 90, 180 or 270 degrees."""
        rotate_angle = random.choice([0, 90, 180, 270])
        return F.rotate(image_pil, rotate_angle, expand=True)

    def _rand_horizontal_flip(self, image_pil):
        """Flip horizontally with probability 0.5."""
        if random.random() < 0.5:
            image_pil = F.hflip(image_pil)
        return image_pil

    def _rand_vertical_flip(self, image_pil):
        """Flip vertically with probability 0.5."""
        if random.random() < 0.5:
            image_pil = F.vflip(image_pil)
        return image_pil

    def _random_crop(self, image_np):
        """Return a random ``crop_size`` x ``crop_size`` window of ``image_np``.

        Raises:
            ValueError: If the crop size exceeds either image dimension.
        """
        height, width = image_np.shape[:2]
        size = self._crop_size
        if size > height or size > width:
            # Fail with a readable message instead of randint's "empty range".
            raise ValueError(
                f"crop size {size} exceeds image size {height}x{width}"
            )
        top = random.randint(0, height - size)
        left = random.randint(0, width - size)
        return image_np[top: top + size, left: left + size]
2、迭代式(iterable-style)
这个我没用过,所以只把官方链接放在这里:https://pytorch.org/docs/stable/data.html#iterable-style-datasets