时不我与

V1

2022/02/27阅读:41主题:橙心

2.6

日期: 2.6 - 2.7

论文题目: Deep CORAL: Correlation Alignment for Deep Domain Adaptation

论文地址:https://arxiv.org/pdf/1607.01719.pdf

1. 解决的问题

深度神经网络可以在大规模的标注数据中学校到特征,但是输入数据分布不同的时候泛化不是很好。因此提出了domain adaptation来弥补性能。本文针对target domain没有标注数据情况,对CORAL进行了改进。

2. CORAL

CORAL方法用线性变换方法将源域和目标域分布的二阶统计特征进行对齐。对于无监督域适应效果很好。问题出在依赖的是线性变换,而且不是端到端训练。训练分为两步,首先提取特征,应用变换,然后训练SVM分类。

3. 主要贡献

对CORAL算法扩展,使用非线性变换。将其应用到深度网络中,对源域和目标域的CORAL loss优化到最小。非线性变换更强大,并且可以与CNN无缝对接。

4. 优点

比DDC更强大,比DAN优化起来更容易,可以无缝集成到CNN结构中。

5. 网络结构

和其他深度自适应模型类似,网络对source和target domain的数据一起训练,backbone部分共享参数。source数据的输出与source label进行监督训练,得到分类的loss。而target数据的输出由于没有标注数据进行监督训练,因此要和source进行适应,计算CORAL loss。最终的目的是要将分类loss和CORAL loss共同优化到最小,即source的分类更精确,target的输出与source的分布更相似。

损失函数由两部分组成的原因有两个,要对两部分同时优化:

  1. 最小化分类loss本身会导致模型对源域过拟合,在目标域上性能很差
  2. 只单单对CORAL loss优化会恶化特征。网络会将source和target数据映射,如果映射到了同一个点,CORAL loss会变为0,这样的特征是不能构建强大的分类器的。

下面贴一下coral loss的pytorch代码,代码来源于王晋东博士的github

import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def CORAL(source, target):
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)

    # source covariance
    tmp_s = torch.ones((1, ns)).to(DEVICE) @ source
    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)

    # target covariance
    tmp_t = torch.ones((1, nt)).to(DEVICE) @ target
    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

    # frobenius norm
    loss = (cs - ct).pow(2).sum().sqrt()
    loss = loss / (4 * d * d)

    return loss

分类:

后端

标签:

后端

作者介绍

时不我与
V1