Sung, Flood, et al. “Learning to Compare: Relation Network for Few-Shot Learning.” CVPR 2018
1 简介
Relation Network(以下简称 RN)是 MN、PN 的改进版,把 MN 和 PN 固定的距离度量函数(metric)(余弦距离 or 欧氏距离)变成可以用神经网络学习的非线性函数。因此,对于第一步得到的 embedding,如果线性不可分,RN 的性能将比 PN MN 好得多。(这点我并不同意,后面会说)
2 算法
RN 的结构如图所示:

可以分成两大步:第一步 embedding,学习方式与 MN 相同。第二步用 relation module 决定 query 属于哪一类。作者把它看做是 PN 和 MN 的 metric 的非线性拓展。这点是事实,我同意。下面我就对 1-shot 和 0-shot 分别详细解释模型结构。
2.1 1-shot
令 RN 两部分 embedding module 和 relation module 所拟合的函数分别为 $f_\varphi$ 和 $g_\phi$。图中左边是 support set $\{x_i\}$,下面是 query $x_j$,先通过 embedding module 得到 $f_\varphi(x_i)$ 和 $f_\varphi(x_j)$,然后把 $f_\varphi(x_j)$ 复制 $N$ 份(N-way),和 support 的 embedding 拼到一起(用 $\mathcal{C}$ 表示),用 $\mathcal{C}(f_\varphi(x_i),f_\varphi(x_j))$ 输入 relation module,输出这一对儿属于一类的概率:$$r_{i,j} = g_{\phi}(\mathcal{C}(f_\varphi(x_i),f_\varphi(x_j)))$$
2.2 K-shot
只需要在拼接的时候,把这一类的所有 shot 和 query 共 k+1 个向量拼到一起,作为这一类 relation module 的输入。
2.3 0-shot
我以前一直以为 0-shot 就是聚类,看过这篇文章才终于明白:0-shot 是要预先给出每一类的 semantic class embedding vector $v_c$ 作为 support set。那么 query 属于 $v_c$ 这一类的概率为:$$r_{i,j} = g_{\phi}(\mathcal{C}(f_{\varphi 1}(v_c),f_{\varphi 2}(x_j)))$$
具体实现里,作者就直接拿 ImageNet 上用 ResNet 训练好的,这一类的 feature map 作为 $f_{\varphi 1}(v_c)$。
2.4 目标函数
是均方根误差函数,若属于一类,则令其概率接近 1,反之 0:$$\phi,\varphi \leftarrow \mathop{\arg\min}\limits_{\phi,\varphi}\sum_{i=1}^m\sum_{j=1}^n(r_{i,j} – \textbf{1}(y_i==y_j))^2$$
3 实验分析
1、 作者在 Omniglot、miniImageNet 进行了 1-shot 和 k-shot 实验,75% 是 SOTA,margin 也都不大,非常一般。
2、 和 PN 在 embedding 部分网络结构完全相同的情况下,RN 可以用少得多的 way 和 query 得到差不多的结果。作者没有分析原因,我斗胆猜测一下:
首先说 way 的影响。之前 PN 的文章里,为什么 1-shot 时 30-way 最好,5-shot 时 20-way 最好?我认为和聚类一样,这也是可以通过计算 embedding 的轮廓系数得到“分几类”最好的;换句话说,x-way 的数值就和聚类中簇数一样,对性能的确会有影响。
那么为什么不同 shot 对应最好的 way 数不同呢?这个我无法特别严谨地解释,感觉就是和数据本身的分布有关。但是这所谓“本身”的分布也是 embedding 之后的了,因此这个更大程度上是由 embedding 的函数决定的。那么,在 embedding 的部分,是不是可以在目标函数里加一些东西,让 embedding 出来的分布更“好分”呢?这又回到了我 Siamese Net 中的 TODO1。
当然,对于 PN 的最好 way 数,如果要抛弃 embedding 的影响,也能做实验来验证:把每次实验中间得到的 embedding 进行聚类,算轮廓系数,轮廓系数最大时,和 PN 的 30-way、20-way 应该是一致的。这个实验以后再说,这里先放一个【#TODO1】
那为什么 RN 可以用更少的 way 达到 PN 差不多的效果呢?我感觉这里还需要再做实验。RN 和 PN embedding module 结构相同,但是训练目标不同,因此得到的 embedding 分布也不一定相同,我觉得应该先画个图看看 embedding 分布是否一样。
如果 embedding 一样,那就是因为 NN 的分类器能力更强,所以 way 数的影响就不大了?作者怎么连这个实验都不做!那如果 embedding 不一样,那很可能就是上文说过的,embedding 导致分布不同,轮廓系数自然不同。
再说 query 数目的影响。先解释一下,这里说的 query 数的不同,是在同一个 episode 里面的,比如 PN 是 5-shot 5-query,RN 是 5-shot 1-query。query 的数量我理解就是更新参数的次数?所以 RN 收敛更快?这个也是有道理的。不过,还得做一个 RN 在 不同 query 数的实验才好。
最后,既然 PN 里面 30-way 5-query 最好,那为啥这个 RN 的作者不做一下 30-way 的实验,刷一下 SOTA 呢?还是因为在 RN 中,way 和 query 数量对结果都没什么关系了?
以上这么多该做的实验,我就先写在这里吧,不自己做,原因有三:① QQ飞车这个赛季马上结束了,我得上白金 ② 并不重要 ③ 说不定我读的下一篇文章里,就有人分析这些,并有了结论呢
3、 在 0-shot 问题上,RN 性能也只能说一般,10-way 不如 SOTA,50-way 是 SOTA。
4、 对于 relation module 的解释:
作者觉得 RN 的度量函数非线性了,就把 RN 看做是 PN 和 MN 的增强。觉得 RN 好,就好在 RN 可以应对线性不可分的情况,PN 和 MN 只能干瞪眼。作者还特意画出一个线性不可分的数据集让三者的 metric 部分来分类,当然 RN 要好得多。
但我觉得这是一种自欺欺人的做法。如果 embedding 之后的数据分布还那么狂,那要 embedding 还有什么用呢?的确,由于数据少,因此 embedding 网络不能很深,拟合的函数也就不会特别“非线性”,但是作者搞一个螺旋丸的例子,不是在搞笑吗?
我认为,RN 相比于 PN MN,实际上是一种简化。PN 和 MN 是 embedding + 测距,而 RN 实际上整体就可以看做是 embedding,只是最后要得到的是概率,中间加个 concatenation 罢了,实际上就是相当于去掉了 PN 中 测距的部分。并且在学习过程中,embedding 和 relation module 是在一起更新参数的,这比 PN 前面学习,后面定死,当然要好。RN 的好,在我看来,还是验证了我的 intuition:越简单的结构往往越好。Less is more 简直要成为我的座右铭了哈哈哈哈哈
4 相关
作者在这部分写的的确不错,他把 FSL 分成了基于度量学习、基于 RNN 记忆的方法、基于 finetune 三类方法。分别讲了一下。不过另外两类我还没有看论文,我还看不懂。
5 总结
RN 和 PN 相比,砍掉了测距,更好。
6 TODO
- embedding 聚类观察轮廓系数,看最好的类别数;并找到理论上的解释
- 3.2 的一大堆实验