Jian-Guo Zhang, et.al “Discriminative Nearest Neighbor Few-Shot Intent Detection by Transferring Natural Language Inference”, EMNLP 2020
1 简介
这篇文章是做小样本、带 OOS(out of scope,就是选择题的 none of above) 的 intent detection 任务。方法非常简单巧妙:直接用 BERT 在 natural language inference(NLI)上预训练,然后用它做 pair-wise 分类。OOS 问题放到最后计算分数时在解决。
2 方法
2.1 别人的多分类方法
直接将要识别的句子(utterence,因为这个任务场景是在对话系统里,用户说一个 utterence,我们要判断他想干什么,所以这个句子就叫 utterence 了)输入 BERT,再过一层 linear 一层 softmax,就得到 utterence 属于每一类的概率,取最大的就是识别出的 intent 种类。对于有 OOS 的数据集,如果 softmax 算出来的最大一类的概率还小于某个阈值,就说他是 OOV。
2.2 DNNC
改成 pair-wise 输入 BERT,看这两个 utterence 是不是属于同一个 intent:$$h = \text{BERT}(u,e_{j,i})\in \mathbb{R}^d$$
这里 $u$ 和 $e_{j,i}$ 是两个 utterence,其中 $u$ 是我们想知道结果的那个,$e_{i,j}$ 是 $j$ 类的第 $i$ 个样本。直接拼接成 [[cls] u [sep] e_{j,i} [sep]]
输入 bert。如果原来是 N-way K-shot,那么这种 pair wise 的方法,就可以找到 $N\times K\times (K_1)$ 对正例和 $K^2\times N\times (N-1)$ 对负例,这就大大增加了训练样本数。
得到一对 utterence 的 embedding 之后,再过线性层、激活层,得到这对 utterence 的相似度 $S(u, e_{j,i})$$$S(u, e_{j,i}) = \sigma(W \cdot h + b) \in \mathbb{R}$$
上面是计算两个 utterence 的相似度,算法是先把 support set 里面两两都计算完,保存下来,然后就可以通过最近邻得到 $u$ 的类别:$$I(u) = \text{class}\bigg(\arg\max_{e_{i,j}\in E}S(u, e_{j,i})\bigg)$$
就是 $u$ 在样本集合 $E$ 中找到最近的,归为他的类。$\text{class}(e_{j,i})$ 就是“归类”函数。训练的 loss 就用交叉熵。
最后补充一句,上文说的所有 BERT 都是 Sentence BERT[1],并且已经在 NLI 任务上进行了充分的预训练。
2.3 NLI 预训练
输入方式也改成上面的两句话拼接输入 BERT,预训练时的数据,文章中使用了三个 NLI 的数据集。这里有亮点需要说明:
- 用这种方式预训练,得到的模型在 NLI 问题上不如 sota——毕竟是为了后面的 pair 分类准备的
- NLI 上预训练之后,也还是必须要 few-shot fintune 的,因为任务还是不太一样,还是得有个 transfer 的过程
2.4 缩减时间的方法
DNNC 需要对 $(N\times K)^2$ 对样本 embedding(原文 3.4 写的是 $N\times K$,感觉是笔误),太慢了,于是作者就想在所有样本对里面,用粗略的方法找到最接近的 $k$ 对样本,再去用 DNNC 的方法做。
所以作者又用另一个 Sentence BERT,先挨个 embedding 每个句子,然后计算句子 embedding 之间的余弦相似度,取相似度最高的 TOP k 对,用这 k 对句子的集合 $E_k$ 代替上面公式里面所有样本对的集合 $E$。实验发现,$k=20$ 是最好的。我认为这是因为,实验是 5-shot 和 10-shot 嘛,loss 又是交叉熵,所以正负例样本数比较平均的时候才比较好(这个应该和交叉熵的超参无关吧?),20 个就能比较平均(猜的)。
3 实验结果
这篇文章真的是做了好多试验,很多实验方式我都是第一次见,很有意思!对于 OOS prf 的计算也非常有新意,且非常 make sense,不过对我以后工作没啥用,我就不写了。
- 有 OOS 的数据集,sota
- 即使没有 OOS 的数据集,用 50-shot 训练也能达到别人全部数据(100-shot)训练的效果,表明 NLI 预训练很有用
- DNNC 的预测数据最 confident

这张图就是把 softmax 最后得到属于每类的概率统计一下,看看在 [0,1] 的每个区间都有多少个,如果全都接近一或者零,那就说明这个分类器有很强的信心,否则如果所有分数都在中间差不多,就说明这个分类器没底气。这个图真有意思!
其他的实验结论我觉得就都有点“废话”,或者对我没什么启发,就不写了。
4 总结和疑惑
- pair wise BERT + 最近邻分类可以显著解决 OOS,这个理论依据是什么?

如图所示,a 是直接用 bert + softmax 多分类,b 是 pair wise 的 relation network 分类(也用 bert 做 embedding),c 是本文提出的 pair wise BERT(不在 NLI 预训练) + 最近邻分类,d 是 c 加上 NLI 预训练。
b 不能分开 OOS,可见识别 OOS 与 pair wise 训练无关,那么 OOS 识别就一定是最近邻的功劳,那他的理论依据是什么呢?
a 的 softmax 这种“概率小于阈值”来判断 OOS 的方式,在理论上是不太 make sense 的。因为如果最大概率还小于阈值,就表明 softmax 得到的各类分数都差不多,也就是说,在隐空间里,这个 utterence 和每个类的距离都差不多。因此这里并没有限定 OOS 样本在隐空间的位置,所以最后,OOS 的例子并没有能在隐空间里面自成一堆。
反观 c 和 d 的最近邻的方法,他要求 utterence 和最相似的类别的相似度小于阈值,这样才是 OOS。这其实就是在隐空间里对于 OOS 样本的位置有了一个限制:必须要距离其他所有类别的簇足够远才可以。这样好像也没有显式要求 OOS 的具体位置,但是至少要求让他们分开了。
这时我就想到一个问题:那有没有必要、有没有可能,显式要求 OOS 样本在隐空间的具体位置,即显式地让他们成为一簇或者几簇呢?似乎是没有必要的。
- 如果最近邻有很好的识别 OOS 能力,那么 prototypical networks 应该也能有不错的识别 OOS 的能力。
但是作者在讨论部分说,他用了 Prototypical networks,但是不好:We also considered several ideas from prototypical networks (Sun et al., 2019), but those did not outperform our Emb-kNN baseline. These results indicate that deep self-attention is the key to the nearest neighbor approach with OOS detection.
他这句话中的 deep self-attention,如果我没理解错的话,就只是 BERT 里面的结构(在文中搜索 deep self-attention,他原文是 deep self-attention in BERT)。
首先,我认为他说的是错的,b 里面不也有 BERT,但是没能分开 OOS。我认为能让 OOS 在隐空间分得那么开,如前文所述,是最近邻的功劳。
那为什么作者用 Prototypical networks 不好呢?我觉得是因为,prototypical networks 必须显式地去找 prototype,这就导致无法用那种 pair wise 的方法,样本不够,可能就不太理想。唉,因此 prototype 和 pair-wise 就是矛盾的吗?
- 这种从多分类降到二分类的方法,尤其是直接暴力拼上输入 bert,我觉得在所有 NLP 的问题都是适用的。这种方法,一来,让样本数量平方,二来,直接
[[cls] sentence1 [sep] sentence2 [sep]]
这样输入 BERT,也非常简单。这在小样本的情境下,我认为对性能的提升也是普适的(还需要实验验证)。那么这种方法至少能用在关系、句子分类问题上吧?是不是用这种方法之后,小样本也能 fintune BERT 了呢? - 对于分类 confidence 的展示,值得学习一下。比如这个应该也能来证明“我们的 prototype 比其它模型的 prototype 好”,结合 t-SNE 的具体图片,简直完美。
5 TODO
[1] Nils Reimers and Iryna Gurevych. 2019. Sentence- BERT: Sentence Embeddings using Siamese BERT- Networks. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Process- ing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 3982–3992.