torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。
1. torchvision.datasets
torchvision.datasets中包含了以下数据集
- MNIST
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
'USPS', 'Kinetics400', 'HMDB51', 'UCF101')
例如,我们可以通过datasets.CIFAR10获得一个CIFAR10的数据集类对象。例如:
dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)
除此之外,datasets下的Datasets类都是torch.utils.data.Dataset的子类,所以,这些类我们都可以直接拿来用,它返回的是torch.utils.data.Dataset的子类对象。最常用的类是datasets.ImageFolder(DatasetFolder),其中DatasetFolder类间接继承了torch.utils.data.Dataset。
class ImageFolder(DatasetFolder):
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader, is_valid_file=None):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples
参数说明:
- root:数据集所在的文件夹
- transform:一个函数,原始图片作为输入,返回一个转换后的图片。用
torchvision.transforms产生。 - target_transform:一个函数,输入为
target,输出对其的转换。例子,输入的是图片标注的string,输出为word的索引。
2. torchvision.models
torchvision.models模块的 子模块中包含以下模型结构。
from .alexnet import *
from .resnet import *
from .vgg import *
from .squeezenet import *
from .inception import *
from .densenet import *
from .googlenet import *
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
你可以使用随机初始化的权重来创建这些模型。
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()
也可以使用预训练的模型,只要设置pretrained=True即可。
import torchvision.models as models
#pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
3. torchvision.transforms
这是一个图像转换工具。
3.1 对PIL.Image进行变换
class torchvision.transforms.Compose(transforms)
将多个transform组合起来使用。
transforms: 由transform构成的列表. 例子:
transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
### class torchvision.transforms.Scale(size, interpolation=2)
将输入的`PIL.Image`重新改变大小成给定的`size`,`size`是最小边的边长。举个例子,如果原图的`height>width`,那么改变大小后的图片大小是`(size*height/width, size)`。
**用例:**
from torchvision import transforms
from PIL import Image
crop = transforms.Scale(12)
img = Image.open('test.jpg')
print(type(img))
print(img.size)
croped_img=crop(img)
print(type(croped_img))
print(croped_img.size)
结果:
<class 'PIL.PngImagePlugin.PngImageFile'>
(10, 10)
<class 'PIL.Image.Image'>
(12, 12)
