RAG中的双编码器与跨编码器模型
📖阅读时长:19分钟
🕙发布时间:2025-02-13
近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企】
公众号【柏企科技说】【柏企阅文】
检索增强生成(RAG)是一个强大的框架,它结合了基于检索和基于生成的自然语言处理(NLP)任务方法。RAG不只是依赖生成模型,而是通过检索相关文档或段落来利用外部知识源,进而提高生成内容的准确性和质量。这种检索与生成的融合,让RAG能更有效地处理问答、摘要以及知识密集型生成等任务。
RAG通常会用到两个关键组件:
- 检索模块:从大型语料库中检索相关文档。
- 生成模块:像BART、T5或GPT这样的生成模型,利用检索到的文档生成最终输出。
有两种主要策略可以优化检索过程,即双编码器(Bi-encoders)和跨编码器(cross-encoders)。这些方法加强了查询和文档之间的交互,以便在将检索到的文档输入生成模型前,获得更好的检索结果或优化已检索的文档。
本文将探讨如何在RAG的背景下实现双编码器和跨编码器模型。
双编码器模型
在双编码器模型中,查询和文档会通过两个神经网络分别编码为向量表示,这两个神经网络通常架构相同。随后,利用相似性函数(比如点积或余弦相似度)对这些嵌入向量进行比较,以此检索相关文档。
优势
- 对于大规模检索非常高效,因为文档嵌入向量能预先计算并建立索引。
- 支持快速进行最近邻搜索。
弱点
在编码阶段,查询和文档之间缺乏细致的交互,这可能会影响检索质量。
跨编码器模型
与之不同的是,跨编码器模型会将查询和文档拼接起来,再通过单个transformer模型进行处理。这样可以实现更深入的交互,有可能得出更高质量的相关性分数。
优势
- 在查询和文档之间提供细致的注意力机制。
- 往往能产生更准确的相关性分数。
弱点
- 比双编码器慢很多,因为每一对查询 - 文档都必须从头开始编码。
- 不太适用于大规模数据集的扩展。
在RAG中实现双编码器和交叉编码器
步骤1:设置环境
在使用双编码器和跨编码器实现RAG之前,你需要准备以下这些库:
- Transformers(由Hugging Face提供):为双编码器和跨编码器设置提供预训练模型。
- FAISS(Facebook AI Similarity Search):支持对大规模数据集进行高效的相似性搜索,这是双编码器所必需的。
- Pytorch或TensorFlow:用于训练和模型管理。
- Datasets:用于加载和管理数据集(比如Hugging Face的datasets库)。
你可以使用以下命令安装这些必要的库:pip install transformers faiss-gpu datasets torch
步骤2:数据准备
为检索增强生成准备好数据集。你需要一个文档语料库以及相应的查询(比如问答对,或者用于生成的上下文文档)。
步骤3:实现双编码器
要创建双编码器模型,可以使用两个独立的transformer(或一个共享的transformer)为查询和文档生成嵌入向量。
定义双编码器模型:(补充定义模型的代码,如使用
transformers
库构建模型结构的代码示例)from transformers import AutoTokenizer, AutoModel import torch tokenizer = AutoTokenizer.from_pretrained('your_model_name') query_encoder = AutoModel.from_pretrained('your_model_name') document_encoder = AutoModel.from_pretrained('your_model_name')
对语料库和查询进行编码:(给出编码的代码示例,包括如何将文本转换为模型可接受的输入并获取嵌入向量)
def encode_text(text, encoder, tokenizer): inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True) with torch.no_grad(): outputs = encoder(**inputs) embeddings = outputs.pooler_output return embeddings
使用FAISS执行检索:计算完嵌入向量后,就可以用FAISS检索最相关的文档。(补充FAISS检索的代码,如构建索引、执行搜索等操作的示例)
import faiss # 假设已经有文档嵌入向量docs_embeddings和查询嵌入向量query_embedding d = docs_embeddings.shape[1] index = faiss.IndexFlatL2(d) index.add(docs_embeddings) distances, indices = index.search(query_embedding, k=5)
在RAG中使用检索到的文档:检索到的文档可以输入像BART或T5这样的生成模型,以生成最终输出。(给出将检索结果输入生成模型的代码示例)
from transformers import BartForConditionalGeneration, BartTokenizer model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') retrieved_docs = [docs[i] for i in indices[0]] input_text = " ".join(retrieved_docs) input_ids = tokenizer(input_text, return_tensors='pt').input_ids outputs = model.generate(input_ids) generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
步骤4:实现跨编码器
对于跨编码器,必须将每对查询 - 文档一起处理,以生成相关性分数。
定义跨编码器模型:(给出定义跨编码器模型的代码,例如使用
transformers
库构建模型的代码片段)from transformers import AutoTokenizer, AutoModelForSequenceClassification tokenizer = AutoTokenizer.from_pretrained('your_cross_encoder_model_name') cross_encoder = AutoModelForSequenceClassification.from_pretrained('your_cross_encoder_model_name')
评估查询 - 文档对:针对每个查询,评估每个候选文档的相关性分数。(补充评估分数的代码示例,包括如何输入数据到模型并获取分数)
def evaluate_pairs(query, docs, cross_encoder, tokenizer): scores = [] for doc in docs: inputs = tokenizer(query, doc, return_tensors='pt', padding=True, truncation=True) with torch.no_grad(): outputs = cross_encoder(**inputs) logits = outputs.logits score = torch.softmax(logits, dim=1)[0][1].item() scores.append(score) return scores
使用最佳文档进行生成:现在,可以像双编码器的工作流程一样,在生成模型中使用跨编码器排序后的最相关文档。(给出使用排序后的文档进行生成的代码示例)
# 假设已经有查询query、文档列表docs、跨编码器模型cross_encoder和分词器tokenizer scores = evaluate_pairs(query, docs, cross_encoder, tokenizer) sorted_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True) best_docs = [docs[i] for i in sorted_indices[:5]] # 后续将best_docs输入生成模型的代码与双编码器类似
步骤5:结合双编码器和跨编码器
可以先利用双编码器进行高效检索,再用跨编码器优化结果。这种混合方法兼具双编码器的可扩展性和跨编码器的精确性。
结论
在RAG中实现双编码器和跨编码器模型,能够实现高效且高质量的文档检索,显著提升生成任务的性能。双编码器支持快速的大规模检索,跨编码器则提供更精确的相关性评分,最终能生成更优质的内容。
## 推荐阅读
1. DeepSeek-R1的顿悟时刻是如何出现的? 背后的数学原理
2. 微调 DeepSeek LLM:使用监督微调(SFT)与 Hugging Face 数据
3. 使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT
4. DeepSeek R1:了解GRPO和多阶段训练
5. 深度探索:DeepSeek-R1 如何从零开始训练
6. DeepSeek 发布 Janus Pro 7B 多模态模型,免费又强大!
本文由mdnice多平台发布
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。