文章目录
摘要
BERT发表时在很多任务上(比如:文本语义相似度)都取得了当时最好的结果。但是在一些任务上,BERT计算性能较差。在使用中,BERT需要将两个句子拼接在一起,输入到模型中,这就引入了大量的计算预处理:比如在10000个句子中找到最相似的句子对要计算大约5000万次的模型推理(inference)计算。BERT的结构导致它不适合做语义相似度搜索任务以及其它一些类似非监督的聚类任务等。
这篇论文提出了一个新的模型:Sentence-BERT (SBERT),它在预训练的BERT基础之上,使用孪生网络(siamese network)和三元组网络(triplet network),对模型进行了微调。这样模型就可以生成句子的嵌入表示(sentence embedding),这些句子的嵌入表示可以直接用来计算句子间的余弦相似度。这样在前面提到的寻找最相似句子的任务中,就无需对每个句子对进行推理,大大减少了计算时间(65小时 $\to$ 5秒),并且不会降低准确率。
作者基于预训练的BERT和RoBERTa(可以看作BERT增强版)这两个模型进行了实验,表现超过了当时最好的句子嵌入方法。
背景
摘要中提到了BERT在一些基于句子对的任务上的一些缺点:对于有N个句子的任务来说,那么使用BERT来对每个句子对进行推理需要进行$O(N^2)$次,这会耗费大量的时间。
常用的解决句子聚类/语义搜索任务的问题的方法是进行句子嵌入(sentence embedding),也就是将句子映射到一个向量空间中,句子间的相似度可以用其在向量空间中的相似度来衡量。这样基于向量来计算速度就可以极大的加快。虽然BERT本身也支持一种嵌入的方法:将单句子输入到BERT中,然后使用$[CLS]$(标志句子起始的token)对应的词嵌入来表示整个句子的嵌入。这种方法在实际使用中,性能表现经常不行,很多情况下都比不上GloVe [2]。
详解
SBERT模型在BERT模型后面添加了一个池化层用于输出一个固定长度的向量。文中实验了三种类型的池化操作:
- CLS: 使用CLS token对应的词嵌入作为句子嵌入
- MEAN: 使用所有输出向量的均值
- MAX: 使用输出向量的最大值(在时间维度上计算)
为了对BERT模型做微调,SBERT引入了孪生网络和三元组网络用于训练,这样可以使得到的句子嵌入在向量空间中的距离可以反映句子间的相似性。
论文中实验了三种微调方式,分别对应了三个目标函数:
- 分类目标函数
- 回归目标函数
- 三元组目标函数
其中分类和回归目标函数对应使用了孪生网络结构,三元组目标函数使用三元组网络结构。
对于孪生网络和三元组网络,读者可以看我们之前写的文章。
网络结构
分类目标函数
分类目标函数为:
使用该目标函数进行微调的时候,使用了孪生网络结构,SBERT对应网络结构如下图所示:
图中左右两部分为相同的网络(结构、参数都一样)。其中$u, v$分别对应了句子1和句子2的嵌入表示。得到句子嵌入之后,将两个句子的嵌入$u, v$以及其差值$u-v$拼接到一起,然后输入到一个全连接层,最后使用$softmax$得到分类结果。文中使用了交叉墒$loss$函数对模型进行微调。
微调训练完成后,仅需要$BERT$层和$Pooling$层即可进行前向推理,以得到句子的嵌入表示。
回归目标函数
回归目标函数为$cos(u, v)$,并且使用均方差误差($mse$)作为$loss$函数进行微调训练:
对应的网络结构如下:
该网络结构与前一个基本一样,都需要得到两个句子的嵌入$u, v$,只不过此处使用了余弦相似度作为衡量两个句子的距离。微调完成后同样只需要保留$BERT + Pooling$层即可。
三元组目标函数
三元组目标函数对应的SBERT用于微调的网络结构用到了三元组网络,与孪生网络结构不同的是:孪生网络使用了两个完全一样的网络结构,而三元组网络使用了三个。三元组网络有三个输入$(a, p, n)$,其中句子$a$作为锚点,$p$表示与$a$同类型的句子,$n$表示与$a$为不同类型的句子。我们使用函数$f(\cdot)$表示$BERT+Pooling$两层的功能,显然$f(x)$表示句子$x$的嵌入向量表示。三元组网络对应的目标函数为:
网络结构如下:
三元组$loss$函数的目标是拉近$a/p$之间的距离,而拉远$a/n$之间的距离。$\epsilon$为一个超参数,表示希望$n$距离$a$的距离至少要比$p$与$a$的距离远$\epsilon$。
微调训练
数据集
文中训练SBERT使用了两个自然语言推理(NLI)数据集:
- SNLI [3]: 该数据集中包含了约570000个句子对,是一个三分类的数据集
- Multi-Genre NLI [4]: 该数据集包含了约430000个句子对
超参数
- Batch size: 16
- 优化器: Adam (2e-5的学习率)
- 训练时使用了10%数据对学习率进行线性$warm\ up$
- 默认池化方式为MEAN
实验结果
语义相似性任务
文中进行了多组实验,此处我们给出其中两个结果。
AFS 数据集(辩论观点相似度)
该数据集使用0~5六个等级衡量观点的相似度,其中数据集中包含了三个话题:
- 枪支控制
- 同性恋婚姻
- 死刑
文中使用了两种测试方式:
- 10层交叉验证
- 使用其中两个话题训练,另外一个话题测试
此处SBERT使用了上文的回归loss函数的模型进行了微调。余弦相似度被用来衡量句子间的相似的得分。得到句子相似度的分值后,实验结果计算了预测值与标签之间的相关系数,实验结果如下:
其中
- $r$为Pearson相关系数
- $\rho$为Spearman相关系数
在此处的两组实验中,我们可以看到SBERT的结果比不上仅仅使用BERT。文中给出的解释是:BERT得益于其中使用的注意力机制可以直接逐词对句子中的单词进行比较,而SBERT得到了一个句子的嵌入向量表示,对于相似观点及理由的样本就较难区分,因此效果并不好。作者表示,这个数据集中仅包含了三个主题,使用其中两个去训练,样本远远不够,因此,若想要达到与BERT差不多的效果,SBERT需要更多的数据。
维基百科句子对分类(属于同一段/不属于)
此处的模型使用了三元组网络进行微调,微调时每个epoch使用了1800000个三元组句子进行训练,测试集使用了从不同的文档中提取出的222957个三元组。对于一个三元组,如果模型给出的$a, p$之间的距离(相同段中的句子)比$a, n$之间的距离(不同段)小,那么该三元组标记为成功;否则失败。测试结果如上图所示。基于SBERT方法性能明显比较好。SBERT/SRoBERTa (base/large)分别是基于预训练的BERT/RoBERTa (base/large)进行微调的。
文本嵌入任务
文中使用SentEval[6]工具来测量SBERT输出的句子嵌入的性能表现,结果如下:
可以看到SBERT在大多数任务上都取得了较优的性能表现,平均表现也是最棒的。
结束语
本文提出的SBERT克服了BERT在一些任务上句子嵌入方面的缺点,在BERT基础上引入了孪生网络和三元组网络对BERT网络进行微调。SBERT在很多任务上都取得了最好的表现。在特定任务(比如:句子的层次聚类任务)上,计算性能方面远远优于BERT。
引用
[1] Reimers, Nils, and Iryna Gurevych. "Sentence-bert: Sentence embeddings using siamese bert-networks." arXiv preprint arXiv:1908.10084 (2019).
[2] Pennington, Jeffrey, Richard Socher, and Christopher D. Manning. "Glove: Global vectors for word representation." Proceedings of the 2014 conference on empirical methods in natural language processing (EMNLP). 2014.
[3] Bowman, Samuel R., et al. "A large annotated corpus for learning natural language inference." arXiv preprint arXiv:1508.05326 (2015).
[4] Williams, Adina, Nikita Nangia, and Samuel R. Bowman. "A broad-coverage challenge corpus for sentence understanding through inference." arXiv preprint arXiv:1704.05426 (2017).
[5] Dor, Liat Ein, et al. "Learning thematic similarity metric from article sections using triplet networks." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers). 2018.
[6] Conneau, Alexis, and Douwe Kiela. "Senteval: An evaluation toolkit for universal sentence representations." arXiv preprint arXiv:1803.05449 (2018).
更多推荐