咚咚

V1

2022/02/27阅读:32主题:默认主题

<pytorch系列2>:Transforms

有时候数据初始形式与算法所需要的不相匹配。所以需要使用Transforms对数据进行一些操作,使其适合于训练

所有 TorchVision 自带的dataset类都有两个参数

  • 用于修改输入特征的transform

  • 用于修改标签的 target _ transform

    Torchvision.transforms 模块提供了几种常用的变换


下面以torch自带的FashionMNIST进行说明

FashionMNIST 输入图像特征是 PIL 图像格式,标签是整数。

训练时,需要将图像特性转化为规范化张量,并将标签转化为one hot编码张量。

为了完成这些转换,使用了Torchvision.transforms 模块提供的 ToTensor 和 Lambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor将PIL图像或NumPy ndarray转换为浮点张量。并将图像的像素值缩放到**[0,1]**范围内

Lambda Transforms

Lambda transforms会将用户定义的 Lambda 函数应用到数据集中。

上述案例中,我们定义一个函数将整数转换为一个one hot的张量。它首先创建一个大小为10的零张量(标签类别数量) ,然后调用 scatter _,该调用在类别标签y索引上赋值为 1


上述对pytorch自带的Transforms进行讲述,后续会进行自定义Transforms

分类:

人工智能

标签:

深度学习

作者介绍

咚咚
V1

哈尔滨工业大学-计算机视觉