用BERT做语义相似度搜索太耗时?试试Sentence-BERT!
📖阅读时长:19分钟
🕙发布时间:2025-02-13
近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企】
公众号【柏企科技说】【柏企阅文】
BERT和RoBERTa要求将两个句子都输入到网络中,这会带来巨大的计算开销:在10,000个句子的集合中找到最相似的一对,使用BERT大约需要进行5000万次推理计算(约65小时) 。BERT的架构使其不适合语义相似度搜索,也不适合聚类等无监督任务。
Sentence-BERT(SBERT)提出了对预训练的BERT网络的一种改进,它使用连体(siamese)和三元组(triplet)网络结构来生成具有语义意义的句子嵌入,这些嵌入可以使用余弦相似度进行比较。这使得找到最相似句子对的时间从使用BERT / RoBERTa的65小时减少到使用SBERT的大约5秒,同时保持了与BERT相同的准确率。
架构
SBERT在BERT / RoBERTa的输出上添加了一个池化操作,以得到固定大小的句子嵌入。实验了三种池化策略:
- 使用CLS标记的输出。
- 计算所有输出向量的平均值(MEAN策略)。
- 计算输出向量的时间维度上的最大值(MAX策略)。
默认配置是MEAN。
为了对BERT / RoBERTa进行微调,创建了连体和三元组网络来更新权重,以便生成的句子嵌入具有语义意义,并且可以用余弦相似度进行比较。
分类目标函数
句子嵌入向量$u$和$v$与按元素的差值$|u−v|$连接起来,再与可训练权重$W_t$相乘,并优化交叉熵损失。
回归目标函数
计算两个句子嵌入向量$u$和$v$之间的余弦相似度,并使用均方误差损失作为目标函数。
三元组目标函数
给定一个锚点句子$a$、一个正样本句子$p$和一个负样本句子$n$,三元组损失调整网络,使得$a$和$p$之间的距离小于$a$和$n$之间的距离。
训练和评估
SBERT在SNLI和Multi-Genre NLI数据集的组合上进行训练。SNLI是一个包含57万个句子对的数据集,标注了矛盾(contradiction)、蕴含(entailment)和中立(neutral)标签。MultiNLI包含43万个句子对,涵盖了一系列口语和书面文本的体裁。
SBERT的性能是针对常见的语义文本相似度(STS)任务进行评估的。
使用余弦相似度来比较两个句子嵌入之间的相似度。
论文
Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks 1908.10084
## 推荐阅读
1. DeepSeek-R1的顿悟时刻是如何出现的? 背后的数学原理
2. 微调 DeepSeek LLM:使用监督微调(SFT)与 Hugging Face 数据
3. 使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT
4. DeepSeek R1:了解GRPO和多阶段训练
5. 深度探索:DeepSeek-R1 如何从零开始训练
6. DeepSeek 发布 Janus Pro 7B 多模态模型,免费又强大!
本文由mdnice多平台发布
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。