咚咚

V1

2022/02/27阅读:100主题:默认主题

CosineAnnealingLR 代码解析与公式推导

微信公众号:咚咚学AI

CosineAnnealingLR是一种学习率scheduler决策

概述

其pytorch的CosineAnnealingLR的使用是

torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=- 1, verbose=False)

由上可以看出其主要参数为

  1. optimizer:为了修改其中的参数学习率,提供初始学习率
  2. T_max : 整个训练过程中的cosine循环次数
  3. eta_min:最小学习率,默认为0
  4. last_epoch:上一次epoch的索引,便于计算当前的学习率,默认为-1

其中最主要的修改参数为T_max和eta_min,首先通过直接修改这两个值来观察整个学习率变化曲线

令初始学习率为1.0,整个epoch为100

  1. 下图中T_max=10, eta_min=0.
  1. 下图中T_max=20, eta_min=0.5

由上两图可以看出,T_max可以看做coisne函数的半个周期长度,eta_min就是表示最小学习率

代码

接下来我们看看其代码实现,便于我们理解以及后续自己修改

学习率类型主要都是继承**_LRScheduler**类型,具体可以查看之前的文章(https://zhuanlan.zhihu.com/p/469323798)

从之前文章可知,获取学习率的方式有get_lr()和_get_closed_form_lr()两个函数

其中get_lr()多是根据上一次的学习率进行迭代计算的

_get_closed_form_lr()是根据当前epoch直接计算的

_get_closed_form_lr()

def _get_closed_form_lr(self):
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                for base_lr in self.base_lrs]

这个代码表达的公式1为

其中, 就是上文的eta_min,最小学习率

表示初始学习率,也是最大学习率

表示当前的epoch

大小范围为-1~1

大小范围为0~2

前面再除以2,使其大小范围为0~1

再乘以 使其大小范围为0~

最后加上 使其大小范围为 ~

最终的函数曲线就可以想象的出来了,与上述两个图保持一致

get_lr()

def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'for group in self.optimizer.param_groups]
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [group['lr'] + (base_lr - self.eta_min) *
                    (1 - math.cos(math.pi / self.T_max)) / 2
                    for base_lr, group in
                    zip(self.base_lrs, self.optimizer.param_groups)]
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]

get_lr()函数就相对较为复杂,咱们逐步分析

其中根据 分成两个情况考虑,逐个考虑

其中代码表示的公式2如下

上文已经表述,get_lr()函数是根据上一个学习率更新当前学习率的

所以需要知道两个学习率之间差别,由公式1可知 公式3

公式4

将以上公式3 4进行比值就能推导出公式2

但是

这个公式存在一个问题,分母可能为0,当 等于2k+1(奇数)的时候,这个公式就不成立了,所以要设立判断条件,同时也引出了第二种情况

代入公式3

得到公式5

同理,将 代入公式4

得到公式6

根据公式5 6就能得到

与代码一致

分类:

人工智能

标签:

深度学习

作者介绍

咚咚
V1

哈尔滨工业大学-计算机视觉