你的位置:首页 > 信息动态 > 新闻中心
信息动态
联系我们

torchvision

2021/12/18 15:46:09

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

1. torchvision.datasets

torchvision.datasets中包含了以下数据集

  • MNIST
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10
__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
           'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
           'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
           'USPS', 'Kinetics400', 'HMDB51', 'UCF101')

例如,我们可以通过datasets.CIFAR10获得一个CIFAR10的数据集类对象。例如:

dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)

除此之外,datasets下的Datasets类都是torch.utils.data.Dataset的子类,所以,这些类我们都可以直接拿来用,它返回的是torch.utils.data.Dataset的子类对象。最常用的类是datasets.ImageFolder(DatasetFolder),其中DatasetFolder类间接继承了torch.utils.data.Dataset

class ImageFolder(DatasetFolder):

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, is_valid_file=None):
        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                                          transform=transform,
                                          target_transform=target_transform,
                                          is_valid_file=is_valid_file)
        self.imgs = self.samples

参数说明

  • root:数据集所在的文件夹
  • transform:一个函数,原始图片作为输入,返回一个转换后的图片。用torchvision.transforms产生。
  • target_transform:一个函数,输入为target,输出对其的转换。例子,输入的是图片标注的string,输出为word的索引。

2. torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。

from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *

你可以使用随机初始化的权重来创建这些模型。

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

也可以使用预训练的模型,只要设置pretrained=True即可。

import torchvision.models as models
#pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)

3. torchvision.transforms

这是一个图像转换工具。

3.1 对PIL.Image进行变换

class torchvision.transforms.Compose(transforms)

将多个transform组合起来使用。

transforms: 由transform构成的列表. 例子:

transforms.Compose([
     transforms.CenterCrop(10),
     transforms.ToTensor(),
 ])

### class torchvision.transforms.Scale(size, interpolation=2)

将输入的`PIL.Image`重新改变大小成给定的`size`,`size`是最小边的边长。举个例子,如果原图的`height>width`,那么改变大小后的图片大小是`(size*height/width, size)`。
**用例:**

from torchvision import transforms
from PIL import Image
crop = transforms.Scale(12)
img = Image.open('test.jpg')

print(type(img))
print(img.size)

croped_img=crop(img)
print(type(croped_img))
print(croped_img.size)

结果

<class 'PIL.PngImagePlugin.PngImageFile'>
(10, 10)
<class 'PIL.Image.Image'>
(12, 12)

3.2 对Tensor进行变换

3.3 Conversion Transforms

3.4 通用变换

4. torchvision.utils