你的位置:首页 > 信息动态 > 新闻中心
信息动态
联系我们

数据加载——torch.utils.data

2021/12/27 15:28:12

 官方文档

from torch.utils.data import DataLoader


DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

参数含义:

关键是dataset如何构建

dataset有两种形式:1、映射式(map-style),一种是迭代式(iterable-style)

1、映射式(map-style)

map_style数据集实现了__getitem__()和__len__()协议,并表示从(可能是非整数的)索引/键到数据样本的映射。

exmaple:用coco数据集举例

import PIL.Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from nets.nets_utility import *
import torchvision.transforms.functional as F

class COCODataset(Dataset):

    # 数据集参数初始化阶段
    def __init__(self, input_dir, crop_size=256, transform=None, need_crop=False, need_augment=False):        
        self._images_basename = os.listdir(input_dir)    # 将数据集路径下的图片名称存储到一个list当中
        if '.ipynb_checkpoints' in self._images_basename:
            self._images_basename.remove('.ipynb_checkpoints')
        self._images_address = [os.path.join(input_dir, item) for item in sorted(self._images_basename)]    # 每张图片的具体位置
        
        # 图片预处理的参数配置
        self._crop_size = crop_size    # 将图片调整为256*256
        self._transform = transform
        self._origin_transform = transforms.Compose([
            transforms.ToTensor()
        ])        
        self._need_crop = need_crop    # 是否需要调整
        self._need_augment = need_augment    # 是否要放大图片
    # 根据索引获取单个数据点
    def __len__(self):
        return len(self._images_address)     # list中有多少张图片
    # 获取数据集总体样本数量
    def __getitem__(self, idx):
        image = cv2.imread(self._images_address[idx], 0) / 255.0        # 0是灰度图,并且除以255进行归一化
        '''图片的预处理'''
        image = cv2.resize(image, (256, 256))  # 调整训练图像尺寸
        if self._need_crop:
            roi_image_np = self._random_crop(image)
        else:
            roi_image_np = image
        roi_image_pil = self._rand_augment(roi_image_np)
        if self._transform is not None:
            roi_image_tensor = self._transform(roi_image_pil)
        else:
            roi_image_tensor = self._origin_transform(roi_image_pil)
        return roi_image_tensor

    def _rand_augment(self, image):
        image_pil = PIL.Image.fromarray(image.astype(np.float32))
        if self._need_augment:
            image_pil = self._rand_horizontal_flip(image_pil)
        return image_pil


    '''下面都是图像预处理的一些手段'''
    def _rand_rotated(self, image_pil):
        rotate_angle = random.choice([0, 90, 180, 270])
        image_pil = F.rotate(image_pil, rotate_angle, expand=True)
        return image_pil

    def _rand_horizontal_flip(self, image_pil):
        # 0.5的概率水平翻转
        if random.random() < 0.5:
            image_pil = F.hflip(image_pil)
        return image_pil

    def _rand_vertical_flip(self, image_pil):
        # 0.5的概率水平翻转
        if random.random() < 0.5:
            image_pil = F.vflip(image_pil)
        return image_pil

    def _random_crop(self, image_np):
        h, w = image_np.shape[:2]
        start_row = random.randint(0, h - self._crop_size)
        start_col = random.randint(0, w - self._crop_size)
        roi_image_np = image_np[start_row: start_row + self._crop_size, start_col: start_col + self._crop_size]
        return roi_image_np

2、迭代式(iterable-style)

这个我没用过,所以只把官方链接放在这里