咚咚
2022/02/27阅读:100主题:默认主题
CosineAnnealingLR 代码解析与公式推导
微信公众号:咚咚学AI
CosineAnnealingLR是一种学习率scheduler决策
概述
其pytorch的CosineAnnealingLR的使用是
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=- 1, verbose=False)
由上可以看出其主要参数为
-
optimizer:为了修改其中的参数学习率,提供初始学习率 -
T_max : 整个训练过程中的cosine循环次数 -
eta_min:最小学习率,默认为0 -
last_epoch:上一次epoch的索引,便于计算当前的学习率,默认为-1
其中最主要的修改参数为T_max和eta_min,首先通过直接修改这两个值来观察整个学习率变化曲线
令初始学习率为1.0,整个epoch为100
-
下图中T_max=10, eta_min=0.

-
下图中T_max=20, eta_min=0.5

由上两图可以看出,T_max可以看做coisne函数的半个周期长度,eta_min就是表示最小学习率
代码
接下来我们看看其代码实现,便于我们理解以及后续自己修改
学习率类型主要都是继承**_LRScheduler**类型,具体可以查看之前的文章(https://zhuanlan.zhihu.com/p/469323798)
从之前文章可知,获取学习率的方式有get_lr()和_get_closed_form_lr()两个函数
其中get_lr()多是根据上一次的学习率进行迭代计算的
_get_closed_form_lr()是根据当前epoch直接计算的
_get_closed_form_lr()
def _get_closed_form_lr(self):
return [self.eta_min + (base_lr - self.eta_min) *
(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
for base_lr in self.base_lrs]
这个代码表达的公式1为
其中, 就是上文的eta_min,最小学习率
表示初始学习率,也是最大学习率
表示当前的epoch
大小范围为-1~1
大小范围为0~2
前面再除以2,使其大小范围为0~1
再乘以 使其大小范围为0~
最后加上 使其大小范围为 ~
最终的函数曲线就可以想象的出来了,与上述两个图保持一致
get_lr()
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.", UserWarning)
if self.last_epoch == 0:
return [group['lr'] for group in self.optimizer.param_groups]
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return [group['lr'] + (base_lr - self.eta_min) *
(1 - math.cos(math.pi / self.T_max)) / 2
for base_lr, group in
zip(self.base_lrs, self.optimizer.param_groups)]
return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
(1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
(group['lr'] - self.eta_min) + self.eta_min
for group in self.optimizer.param_groups]
get_lr()函数就相对较为复杂,咱们逐步分析
其中根据 和 分成两个情况考虑,逐个考虑
其中代码表示的公式2如下
上文已经表述,get_lr()函数是根据上一个学习率更新当前学习率的
所以需要知道两个学习率之间差别,由公式1可知 公式3
公式4
将以上公式3 4进行比值就能推导出公式2
但是
这个公式存在一个问题,分母可能为0,当 等于2k+1(奇数)的时候,这个公式就不成立了,所以要设立判断条件,同时也引出了第二种情况
将 代入公式3
得到公式5
同理,将 代入公式4
得到公式6
根据公式5 6就能得到
与代码一致
作者介绍
咚咚
哈尔滨工业大学-计算机视觉