咚咚

V1

2022/02/12阅读:202主题:默认主题

MAE:自监督视觉预训练新模型(论文+代码_解析)

  1. 训练过程中,对图像patches进行部分mask,使用编码器对可见图像patches进行编码,与masked 图像tokens一起送入解码器,以像素重建原始图像

  2. 经过预训练后,解码器被丢弃,编码器被应用于未损坏的图像(完整的patches集)进行下游识别任务

论文地址:https://arxiv.org/pdf/2111.06377.pdf

代码地址:https://github.com/facebookresearch/mae

Masked Autoencoders Are Scalable Vision Learners

推荐大家关注咚咚学AI公众号,会更新最新Cv论文和AI基本知识


摘要

引入主题 掩码自动编码器(MAE)是一种可扩展的计算机视觉自监督学习器
论文方法 mask输入图像的随机patches,并重建mask的像素。它基于两个核心设计。首先,开发了一个非对称的编码器-解码器体系结构,其中的编码器只在可见的patches子集上运行(没有被掩码),以及一个轻量级解码器,从潜在表示和掩码tokens重建原始图像

Approach

由上图1可知,该模型主要由mask、编码器和解码器组成,下面进行逐一分析

Masking

  1. 与Vit类似,将图像分割成规则的非重叠的小patch。
  2. 然后,我们对其进行采样,并对剩余的patches进行掩码(即删除)。其中抽样策略比较简单直接:随机抽样patches,不更换,遵循均匀分布,称之为“随机抽样”

使用高掩码率随机抽样很大程度上消除冗余,这会导致任务难度变大(见图2 - 4)。

均匀分布防止潜在的中心偏差(例如,对图像中心附近进行过多的mask)。

这种高度稀疏的输入有助于设计一种高效的编码器。

随机mask的代码如下:

 def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """

        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]  (N, L)
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove  (N , L) 返回排序后的值所对应原数据的下标
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]  # (N, len_keep)
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(11, D))  # (N, len_keep, D)

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)  # (N, L)
        mask[:, :len_keep] = 0  # 掩码的设置为0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)  # (N, L)

        return x_masked, mask, ids_restore

MAE encoder

编码器是一个ViT,但只适用于***可见的,未掩码patches***。

  1. 就像在标准ViT中一样,通过添加位置嵌入的线性投影来嵌入patches

  2. 然后通过一系列Transformer块来处理结果集。

    本文的编码器只在小的未掩码子集上工作。这允许训练非常大的编码器,而只占用一小部分的计算和内存。

解码器代码如下:

def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1-1)
        x = torch.cat((cls_tokens, x), dim=1)  # (N, len_keep+1, D)

        # apply Transformer blocks 
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

MAE decoder

MAE解码器的输入是由(i)编码的可见patches和(ii)掩码tokens组成的完整标记集。可参考图1。

每个掩码token都表示一个有待预测的缺失patch。

将位置嵌入添加到这个完整集合中的所有tokens中;如果不这样做,掩码tokens在图像中就不会有关于它们位置的信息。

解码器由一系列Transformer块组成

MAE解码器仅在预训练前用于执行图像重建任务。因此,解码器体系结构可以以独立于编码器设计的方式灵活设计。实验用非常小的解码器。例如,默认解码器与编码器相比,每个token的计算量小10%。采用这种非对称设计,所有tokens只由轻量级解码器处理,这大大减少了预训练时间

解码器代码如下:

 def forward_decoder(self, x, ids_restore): # (N, L+1, D), (N, L)
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)  # (N, L+1-(len_keep+1), D)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token  # (N, L, D)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(11, x.shape[2]))  # unshuffle  (N, L, D)
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token (N, L+1, D)

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection  nn.Linear
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x  # (N, L, p*p*3)

Reconstruction target

MAE通过预测每个掩码块的像素值来重建输入。

解码器输出的每个元素都是表示一个patch的像素值向量。解码器的最后一层是线性投影,其输出通道数等于一个patch中的像素值数。解码器的输出被重塑以形成重构图像。

损失函数计算重建和原始图像在像素空间的均方误差(MSE)。只对掩码patches计算损失,类似于误码率

还研究了以每个掩码块的归一化像素值为重建目标。具体来说,计算一个patch中所有像素的均值和标准差,并用它们来归一化这个patch。在实验中,采用归一化像素作为重建目标,提高了图像的表示质量

损失函数代码:

def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove, 
        """

        target = self.patchify(imgs)  # (N, 3, H, W)->(N, L, p*p*3)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

Experiments

ImageNet识别任务

迁移学习

分类:

人工智能

标签:

图像处理

作者介绍

咚咚
V1

哈尔滨工业大学-计算机视觉