论文笔记 – Optimization as a Model for Few-Shot Learning

Ravi, Sachin, and Hugo Larochelle. “Optimization as a Model for Few-Shot Learning.” ICLR 2017

这篇文章和 [4] 开创了用 LSTM 作为优化器,来 learning to learn 的方法的先河。梯度下降权重更新的表达式其实是 LSTM 里面 cell 的表达式的简化版本,于是就可以通过共享参数的 LSTM 来训练一个“可以训练出分类器的函数”的函数。这里的元学习,其目标是学习出一个参数初值或优化器,而不是像 pre-train 一样为了训练一个比较 generalize 的模型,这个详见 [2]。

阅读本文之前,建议先看 [1],复习 LSTM 的知识,看 [3] 复习 batch normalization 的知识。[2] 其实就是李宏毅老师在课堂上讲的 meta-learning LSTM 的结构,非常有助于理解这篇文章,必看!

1 简介

为什么基于梯度下降的算法在 FSL 不行?作者提出两个原因:第一,非凸问题时,收敛慢,小样本不能让它收敛;第二,每次在新的数据集训练的时候,都要随机初始化,从头再来,这也使得它不能很快收敛。迁移学习倒是可以缓解这个问题,但是如果迁移去的领域和训练的领域相差比较远,那效果还是不太行的。因此作者希望可以学到一个最好的参数初值,基于这个初值,用比较少的样本也能快速收敛。

整个模型的目标函数是评价“从这个初始参数训练出的模型的性能”的性能,这样得到的结果(初始参数),在任何任务上,都可以用小数据,非常快的收敛。

2 算法

2.1 算法原理

传统的梯度下降的参数更新公式为:$$\theta_t = \theta_{t-1} – \alpha \triangledown_{\theta_{t-1}}\mathcal{L}_t$$

很基础,不解释。而在 LSTM 中,cell 的更新函数为:$$c_t = f_t\odot c_{t-1} + i_t\odot \tilde{c}_t$$

其中 $\odot$ 为 hadamard product。我们发现,如果这个式子里 $f_t = 1,c_{t-1} = \theta_{t-1}, i_t = \alpha_t, \tilde{c}^t = -\triangledown_{\theta_{t-1}}\mathcal{L}_t$,那么这将和上面梯度下降的表达式完全相同。

因此,作者提出了 meta-learner LSTM,作为神经网络参数的学习规则。LSTM 的 cell 就是 learner 的参数。从 LSTM 的角度理解,那么更新参数的过程就是,先判断是不是要遗忘($f_t$)原始的参数 $c_{t-1}$,然后再计算这次在梯度方向更新的部分 $\tilde{c}_t = – \triangledown_{\theta_{t-1}}\mathcal{L}_t$求和即为更新的参数。下面就来定义输入门 $i_t$ 和遗忘门 $f_t$:$$i_t = \sigma(W_I\cdot[\triangledown_{\theta_{t-1}}\mathcal{L}_t,\mathcal{L}_t,\theta_{t-1},i_{t-1}]+b_I)$$ $$f_t = \sigma(W_F\cdot[\triangledown_{\theta_{t-1}}\mathcal{L}_t,\mathcal{L}_t,\theta_{t-1},f_{t-1}]+b_F)$$

这两个式子的含义就是,$i_t,f_t$ 都与当前参数值 $\theta_{t-1}$,当前梯度 $\triangledown_{\theta_{t-1}}\mathcal{L}_t$,当前 loss $\mathcal{L}_t$,以及当前的输入门 $i_{t-1}$ 或遗忘门 $f_{t-1}$ 有关。

输入门其实就相当于梯度下降的学习率(梯度前面的数字嘛)。而遗忘门则和梯度下降不同,遗忘门并不一直是 1,在比较不好的情况(loss 很大,但是梯度接近于 0)下,遗忘门可以变小,逃出这个困境。

2.2 参数共享

一个模型有很多很多参数,我们上面说的就是对于模型的一个参数的基于 LSTM 的优化过程。为了简化,作者就令所有参数都用同一个 LSTM 来优化。具体的操作就是,每次输入的都是一批参数,每次更新都取这一批的平均来更新。

另外,作者还说,在 normalize 的时候,因为每个参数都可能差很多,为了让 meta-learner 能使用(?),就使用 [4] 中的标准化方法:

$$ x\rightarrow \begin{cases} (\frac{log|x|}{p},\text{sgn}(x))& \text{if } |x|\ge e^{-p}\ (-1,e^px)& \text{otherwise} \end{cases} $$

这个式子我完全看不懂,还需要看 [4] 的原文【#TODO1】

2.3 训练

训练的方法与 matching network 一致,meta-train 和 meta-test 都是分成 episode 进行,每个 episode 包含 $D_{train}$(就是 FSL 领域的 support set,元学习领域叫 train) 和 $D_{test}$(query)。meta-train 的每个 episode 对应一个 meta-learner 的 training procedure,输出训练好的模型,并在 meta-test 上,训练、测试得到性能结果。

最后得到好的 meta-learner 是什么样呢?无论在哪个 support set 上,随便 train 一下,都可以得到非常好的结果。

我们训练的目标函数是一个 batch 里所有 episode 的 loss 之和。论文里有一张流程图,但并不是很清晰。这里还是推荐去听 [2]。

2.3.1 梯度独立假设

我们可以发现,$\mathcal{L}_t$ 和 $\triangledown_{\theta_{t-1}}\mathcal{L}_t$ 作为 LSTM cell 的两个输入,这俩是相关的。为了简单化,我们就忽略这俩的相关性。这个在 [2] 也有讲。

2.3.2 meta-learner 参数的初始化

参数主要就是两个,输入门 $i_t$ 和遗忘门 $f_t$。遗忘门(其实是记忆门)一开始很大,要接近 1,这样就可以让梯度快速传递。输入门一开始要很小,这样学习率一开始就会很小。(这里我没明白,为什么学习率要从很小开始?【#TODO2】)这样一来,meta-learner LSTM 一开始就和一个小学习率的梯度下降一样,这对于训练的稳定性很有帮助(不懂)。

2.4 batch normalization

这里先要看 [3],了解 batch normalization。batch normalization 在标准化的时候,把原来的方差和均值都记录下来,但是我们不希望在不同的 episode 测试的时候,这些均值方差还能共享。(我的理解:如果能共享,那么测试的时候可能也会学到一些东西)。所以就在 test 的 query 用 support set 的均值方差。(这我觉得是非常正常的思路,不知道为啥还特地写一大段)

3 实验结果

这个模型的具体结构,即原文的 5.0 部分,我没看懂,说是有两层 LSTM,第一层是正常的 LSTM,第二层是上述的 meta-LSTM,这个我完全没明白,两层都有什么用?【#TODO3】

  1. 在 Mini-ImageNet 上,性能和 matching network 差不多
  2. 如果不使用 episode 的方法训练(就是文中的 Baseline-finetune),效果要比最近邻还差,因为这会严重过拟合
  3. 作者对 meta-LSTM 的输入门 $i_t$ 和遗忘门 $f_t$ 这两个参数随训练的变化作图,然后作者就得出结论:meta-LSTM 对遗忘门学到了一个 decay 策略,在不同层之间一致。输入门好像没学到什么策略,这表明两个任务不一样,meta-optimizer 学到的也不一样(废话)

4 相关

MANN(马上看)、SN、MN(这俩都看过了)

5 总结

这篇文章是我第一次接触元学习领域,基础知识不牢,看得迷迷糊糊的,即使写到这里,其实也不是特别明白。以后得专门看一下元学习,然后再看一下这篇文章的姊妹篇 [4],然后再重新回味这篇文章。

这篇文章是我用新买的 magic keyboard 2 敲出来的,感觉手感还不如之前 100 块的 k380 T_T,多出来的小键盘好像也完全没有用,还占地方,鼠标不好移动,好亏 T_T 就是样子货唉

6 引用和拓展阅读

[1] 李宏毅机器学习(2017)-RNN 从 18:00 开始是 LSTM

[2] 李宏毅机器学习2019(国语)-Gradient descent as LSTM

[3] Batch Normalization原理与实战 – 天雨粟的文章 – 知乎

[4] Andrychowicz, Marcin, et al. “Learning to learn by gradient descent by gradient descent.” NeuralIPS 2016

7 TODO

  1. [4]
  2. 为什么学习率要从很小开始?
  3. 具体的模型结构

Leave a Comment

电子邮件地址不会被公开。