咚咚
V1
2022/02/27阅读:32主题:默认主题
<pytorch系列2>:Transforms
有时候数据初始形式与算法所需要的不相匹配。所以需要使用Transforms对数据进行一些操作,使其适合于训练
所有 TorchVision 自带的dataset类都有两个参数
-
用于修改输入特征的transform
-
用于修改标签的 target _ transform
Torchvision.transforms 模块提供了几种常用的变换
下面以torch自带的FashionMNIST进行说明
FashionMNIST 输入图像特征是 PIL 图像格式,标签是整数。
训练时,需要将图像特性转化为规范化张量,并将标签转化为one hot编码张量。
为了完成这些转换,使用了Torchvision.transforms 模块提供的 ToTensor 和 Lambda
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
ToTensor()
ToTensor将PIL图像或NumPy ndarray转换为浮点张量。并将图像的像素值缩放到**[0,1]**范围内
Lambda Transforms
Lambda transforms会将用户定义的 Lambda 函数应用到数据集中。
上述案例中,我们定义一个函数将整数转换为一个one hot的张量。它首先创建一个大小为10的零张量(标签类别数量) ,然后调用 scatter _,该调用在类别标签y索引上赋值为 1
上述对pytorch自带的Transforms进行讲述,后续会进行自定义Transforms
作者介绍
咚咚
V1
哈尔滨工业大学-计算机视觉