论文笔记 – Prototypical Networks for Few-shot Learning

Snell, Jake, Kevin Swersky, and Richard S. Zemel. “Prototypical Networks for Few-shot Learning.” NeuralIPS 2017

1 简介

Prototypical network 属于“embedding 最重要”的一类算法。它将 support set 的同一类样本的所有 embedding 的质心作为这一类的 prototype。query 就直接用自己的 embedding 去找最近的 prototype,就完成了分类。这个想法我也能想出来啊!而且这么简单的方法竟然要比 Matching Network 复杂的 FCE 还要优秀,这是我想不明白的。

这篇文章也是,比较哲学。真正的算法只讲了半页,更多都是在和不同的算法进行比较。我不是特别明白:如果是从其它算法得到的灵感,那写出这些算法,列出相似点,倒也无可厚非;可是作者这样和其它算法比较,有什么意义呢?是要让我们更好的学习和掌握吗?我真的摸不着头脑。

2 Prototypical Network (PN)

2.1 算法

左图是 PN 在 FSL 问题上的用法,用 support set 的 embedding 的中心作为 prototype,query 找到最近的来匹配。右图是 PN 在 Zero-shot 问题上的用法,作者说 prototype 来自这一类的 meta-data,直接 embedding 就是了,这里我没有理解,meta-data 是什么?

下面来说下具体的计算。接下来说的都是 PN 在 FSL 上的用法,Zero-shot 到 2.6 再讲。

计算 prototype $c_k$ 的方法为:$$c_k = \frac{1}{|S_k|}\sum_{(x_i,y_i)\in S_k}f_{\phi}(\mathrm{x}_i)$$

其中 $\mathrm{x}_i$ 是 support set 的样本,$S_k$ 是类别为 $k$ 的 support set 样本集合,$f_{\phi}()$ 是 embedding 函数。与 matching network 一样,整个模型除了 embedding,都是非参数的。那么,如何 embedding 呢?

学习 embedding 过程的目标函数:$$J(\phi) = -\log p_{\phi}(y=k|\mathrm{x})$$ $$p_\phi(y=k|\mathrm{x}) = \dfrac{\exp(-d(f_\phi(\mathrm{x}),c_k))}{\sum_{k’}\exp(-d(f_\phi(\mathrm{x}),c_{k’}))}\tag{1}$$

整体的含义就是对于每一类 $k$,让所有样本的 embedding 和这类的 prototype 的距离,归一化之后最近。优化的过程是 mini-batch SGD。训练过程与 matching network 相同,每个 epoch 都是在全部类别中挑选出来一些,再分成 support set 和 query,这样小样本训练。

2.2 PN Vs Mixture Density Estimation

为了理解这一节,建议先阅读 [1,2,3] 作为前提的基础知识。$\mathrm{z}$ 和 $\mathrm{z}’$ 的 bregman divergence 为: $$d_\varphi(\mathrm{z},\mathrm{z}’) = \varphi(\mathrm{z}) – \varphi(\mathrm{z}’) – (\mathrm{z}-\mathrm{z}’)^T\triangledown\varphi(\mathrm{z}’)$$

这个式子不解释了,就是定义。

那么对于所有指数族分布 $p_\phi(\mathrm{z}|\theta)$($\theta$ 是参数,$\phi$ 是为了区分,这个是分布函数,不是概率密度函数),都可以改写成一个对应的唯一的 bregman divergence 的形式:$$p_\psi(\mathrm{z}|\theta) = \exp\{\mathrm{z}^T\theta – \psi(\theta) – g_{\psi}(\mathrm{z})\} = \exp\{-d_\psi(\mathrm{z},\mu(\theta)) – g_{\varphi}(\mathrm{z})\}$$

前一个等号是指数族分布的定义,后一个是将 $d_\phi$ 代入。

现在考虑一个参数为 $\Gamma = \{\theta_k,\phi_k\}_{k=1}^K$ 的指数族混合模型(GMM 就属于这个),样本 $\mathrm{z}$ 的标签的分布为:$$p(\mathrm{z}|\Gamma) = \sum_{k=1}^K\pi_i p_\psi (\mathrm{z}|\theta_k) = \sum_{k=1}^K \pi_k \exp(-d_\psi(\mathrm{z}, \mu(\theta_k)) – g_\varphi(\mathrm{z}))$$

前一个等号是定义,后一个是将 $p_\psi(\mathrm{z}|\theta)$ 代入。

那么样本 $\mathrm{z}$ 标签为 $k$ 的概率就为:$$p(y=k|\mathrm{z}) = \frac{\pi_k\exp(-d_\varphi(\mathrm{z},\mu(\theta_k)))}{\sum_{k’}\pi_{k’}\exp(-d_\varphi(\mathrm{z},\mu(\theta_k)))}$$

这个跟式(1)除了一些参数,其它完全一致。因此作者说,PN 其实就是在用指数族函数做密度估计。那唯一需要考虑的,就是距离的度量,也就是 $d_\varphi$ 的选择。

小结一下,按照 matching network 里面 non-param 和 param 的思考角度,PN 只是加个距离度量函数的 MDE 罢了。

2.3 PN Vs Linear model

这一部分我感觉作者说的都是废话,不知道他想要表达什么(就是说一堆也没用哇)。

作者说,如果距离选择欧氏距离 $d(\mathrm{z},\mathrm{z}’) = \|\mathrm{z}-\mathrm{z}’\|^2$,那么式(1)所代表的模型就是一个线性模型。

式(1)中的 $-d(f_\phi(\mathrm{x}),c_k)$ 可以写为$$-\|f_{\phi}(\mathrm{x}) – c_k\|^2 = -f_\phi(\mathrm{x})^Tf_\phi(\mathrm{x}) + 2c_k^Tf_\phi(\mathrm{x}) – c_k^Tc_k$$

这个式子的第一项是常数(因为 embedding 肯定 normalize 了),后两项就可以写成 $x\mathrm{x} + b$ 的形式,就是一个神经网络的样子啦(这不废话么,不然你那什么优化)。

2.4 PN Vs Matching network

在 one-shot 的情况,PN 和 MN 是完全等价的。MN 相当于每个样本都是 prototype,PN 是所有一类样本的平均作为 prototype,因此 1-shot PN 与 MN 都是样本自己作为 prototype,完全等价。

作者就又开始发散了:那是一类有多个 prototype 好呢,还是只有一个 prototype 好呢?作者觉得自己做的最好。因为如果有固定的多个 prototype,那么这将进一步地将一类继续细分,就会把参数更新的过程打断了,⑧ 太行。 然后 MN 里面把 support 和 query 的 embedding 函数搞成不同的两个,PN 当然也能用,但是作者对此也是不屑一顾。

2.5 Zero-shot learning

作者这里才说 meta-data vector 是提前给好的。。。具体流程上文说过,不再说了。

3 实验结果

  1. Omniglot、miniImageNet 分类准确率都是 SOTA
  2. 距离使用欧氏距离要比余弦距离好得多,因为余弦距离不是 bregman divergence,就不满足 2.2 的那些推导,就没有密度估计的好处了?(我就顺着说下来的,我也不懂)(换个角度想,算 prototype 的时候是直接平均,那就是在每个维度都平均,因此算 $d$ 的时候也应该是这样的度量嘛)
  3. N-way K-shot,训练的时候 N 取多少比较好呢?作者试了一下,发现 20 左右最好,多了少了都会更差
  4. zero-shot 是 big-margin SOTA,说明 PN 领域迁移的能力很好(可是 meta-data 是给的啊?)

4 总结

  1. 方法很朴素,“我也能想出来”系列。
  2. 作者真的能力很强,能把各种联系到一起,这怎么也得是教授级别了。

5 引用和拓展阅读

[1] 如何理解Bregman divergence? – 覃含章的回答 – 知乎

[2] 高斯混合模型(GMM) – 戴文亮的文章 – 知乎

[3] 指数族分布

Leave a Comment

电子邮件地址不会被公开。