MMOA-RAG 用多智能体强化学习优化检索增强生成模型,效果提升显著!
原文链接:Improving Retrieval-Augmented Generation through Multi-Agent Reinforcement Learning
代码链接:https://github.com/chenyiqun/MMOA-RAG
📖阅读时长:25分钟
🕙发布时间:2025-02-06
近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企】
公众号【柏企科技说】【柏企阅文】
检索增强生成(RAG)管道组件通常通过监督微调单独优化,这可能会导致各个模块的目标与在问答(QA)任务中生成准确答案的总体目标不一致。
为了解决这个问题,本文提出将RAG管道视为一个多智能体合作任务,把每个组件都看作是一个强化学习(RL)智能体,推出了MMOA-RAG,即一种用于RAG的多模块联合优化算法。它采用多智能体强化学习的方式,让所有智能体朝着统一的奖励目标(比如最终答案的F1分数)努力。
前期准备
i) 把RAG建模为协同多智能体强化学习(Co-MARL)
本文将RAG过程放在协同多智能体强化学习(Co-MARL)框架中进行概念化理解。
在这个框架里,RAG管道的每个模块都作为一个独立的RL智能体。这个多智能体系统的总体目标是生成高质量答案,这和每个模块的个体目标是一致的。
定义元组$\langle G, O, A, R\rangle$,其中$G$表示Co-MARL系统中的智能体集合,$O$代表每个智能体可获取的观察信息,$A$构成每个智能体可触及的行动空间,$R$是所有智能体共享的奖励 。
ii) MAPPO算法
本文使用多智能体近端策略优化算法(Multi-Agent PPO,MAPPO),它是近端策略优化算法(PPO)在多智能体环境下的扩展,用于优化Co-MARL框架中每个智能体的策略。
与专注于单智能体场景、使用个体奖励的PPO不同,MAPPO采用共享的全局奖励来促进所有智能体之间的合作。
在PPO中,价值评估模型(critic model)受限于智能体的观察信息,而MAPPO的全局价值评估模型可以获取全面的全局信息,这使得全局价值评估模型能够更准确地估计状态价值函数。
方法概述
下图展示了MMOA-RAG框架的架构,它主要由四个模块组成:查询改写器(Query Rewriter)、检索器(Retriever)、选择器(Selector)和生成器(Generator):
i) 查询改写器
它会重新组织初始查询$q$。初始查询可能过于复杂或模糊,无法通过一次检索解决,查询改写器会将其转化为一组子问题,记为$subq$。
ii) 检索器
检索器会针对每个子问题,分别从语料库中检索相关文档,并输出一组候选文档$D$。
iii) 选择器
选择器会对$D$进一步筛选,得到文档子集$D_{selected}$,这些文档对生成初始查询$q$的最终答案很有用。
iv) 生成器
生成器利用$D_{selected}$,生成对初始问题的预测答案$Ans_{predict}$。
MMOA-RAG框架的重点在于多个模块的协同优化,使它们各自的优化目标与生成高质量答案的最终目标保持一致。
每个智能体的详细配置
i) 查询改写器的要素
- 观察值:定义如下公式,它包含查询改写器的提示$Prompt_{QR}$和初始问题$q$ 。
- 行动空间:与大语言模型(LLMs)的词汇表$V$相对应。
- 奖励函数:定义如下公式。$R_{shared}$可以是衡量最终答案的指标,$P_{QR}$项作为惩罚项,用于防止查询改写器在训练过程中生成过多子问题。如果子问题数量超过四个,$P_{QR}$赋值为 -0.5;若子问题数量为四个或更少,则设置为0。
ii) 选择器的要素
- 观察值:定义如下公式,它包含选择器的提示$Prompt_{S}$、初始问题$q$以及包含$K$个文档的候选文档集$D$。
- 行动空间:由于选择器的功能是输出有助于回答初始问题$q$的候选文档ID,所以行动空间被限制在这个有限的词汇集合中,定义如下。
- 奖励函数:包含两项,即$R_{shared}$和$P_{S}$ 。$P_{S}$是惩罚项,用于防止选择器生成重复的文档ID,以及输出不符合指定格式(例如Document0,Document3,Document9)的ID。当选择器输出重复文档ID或未遵循指定格式时,$P_{S}$设置为 -1;否则,$P_{S}$设置为0。
iii) 生成器的要素
- 观察值:定义如下公式,包含生成器的提示$Prompt_{G}$、初始问题$q$以及选择器给出的选定候选文档集$D_{selected}$。
- 行动空间:定义如下公式,与查询改写器的行动空间相同。
- 奖励函数:包含$R_{shared}$和惩罚项$P_{G}$,$P_{G}$用于限制模型生成过长的内容。当生成的答案超过一定长度时,$P_{G}$设置为 -0.5;否则,设置为0。
使用监督微调(SFT)进行热启动
热启动能让模型在不同任务中更好地遵循指令,减少多智能体强化学习(MARL)联合训练时的探索空间,从而提高探索和利用的效率。
i) 查询改写器
使用近端策略优化算法(PPO)训练一个小型语言模型,以便有效地为RAG改写查询。在此基础上,利用来自Rewrite-Retrieve-Read的公开查询改写数据作为监督微调(SFT)数据集,对MMOA-RAG中的查询改写器进行热启动。
ii) 选择器
选择器的任务是从给定的包含$K$个候选文档的集合$D$中,选择出有助于回答问题的子集$D_{selected}$。选择器的输出格式是$D_{selected}$中文档的ID(例如Document0, Document4, Document6, Document7),如下表中选择器的提示所示:
角色 | 内容 |
---|---|
系统 | 你是一个乐于助人、尊重他人且诚实的助手。你的任务是输出有助于回答问题的候选文档ID(0,1,2,...,K - 1)。 |
助手 | 好的,我会提供有助于回答问题的候选文档ID。 |
用户 | 问题:{问题内容} |
用户 | Document0:{Document0的内容} |
... | ... |
用户 | Document(K - 1):{Document(K - 1)的内容} |
助手 | 好的,我收到了问题和候选文档。 |
用户 | 现在,输出有助于回答问题“{问题内容}”的候选文档ID(0,1,2,...,K - 1),例如按照以下格式:Document0,Document4,Document6,Document7。 |
- 为选择器构建SFT数据:提出了一种简便的启发式方法来构建SFT数据,旨在让大语言模型有效地遵循指令并按期望的形式输出,如下图所示并进一步解释:
对于给定的问题$q_{i}$及其标准答案,有$K$个候选文档记为$d_{i,j}$,其中$j \in \{0, 1, · · · ,K - 1\}$。首先,从$q_{i}$及其标准答案中去除某些无意义的停用词和标点符号,并将单词转换为小写,得到集合$Set_{q_{i}}$。同样地,对$K$个候选文档$d_{i,j}$执行相同操作,得到$Set_{d_{i,j}}$。最后,如果$Set_{q_{i}}$中的任何单词出现在$Set_{d_{i,j}}$中,那么相应文档$j$的ID就会被包含在最终输出中,作为SFT的标签。
iii) 生成器
生成器负责根据选择器提供的$D_{selected}$生成最终答案$Ans_{predict}$。生成器SFT数据的真实值是标准答案$Ans_{golden}$。
多智能体优化
采用与《星际争霸II》中多智能体近端策略优化类似的设置,多个智能体共享一个全局奖励,即使用$R_{shared}$来优化$G$ = {查询改写器(QR)、选择器(S)、生成器(G)}。
为了降低计算开销,在智能体之间应用参数共享机制,让QR、S和G能够使用相同的大语言模型。
i) 模型
在多智能体优化过程中,需要考虑三个模型:
- 策略模型(Actor model):参数记为$\theta$。其作用是基于每个智能体$i$的观察值$O_{i}$提供响应$Answer_{i}$。
- 价值评估模型(Critic model):参数记为$\phi$。负责估计状态价值函数$V_{i,t}^{\phi}$,这是强化学习算法中策略 - 价值(Actor-Critic)架构的经典设置。
- 监督微调模型(SFT model):参数记为$\theta_{SFT}$。作为策略模型的基线,类似于InstructGPT。
ii) 损失函数
目标是更新策略模型和价值评估模型的参数。
总体损失函数$L(\theta, \phi)$由两项组成:$L_{Actor}(\theta)$和$L_{Critic}(\phi)$。
策略模型的损失函数如下公式(12)所示,与典型的单智能体近端策略优化算法中使用的损失函数类似,主要区别在于这里是对多个智能体进行优化。
在公式(12)中,$i \in G$表示三个智能体:查询改写器、选择器和生成器。公式(13)中的$r_{i}^{t}$表示重要性采样比,用于衡量新旧策略之间的差异。
最终奖励函数$R(s_{i}^{t}, a_{i}^{t})$定义如下公式(16)。
价值评估模型的损失函数,如公式(17)所示,采用了与策略模型类似的裁剪操作。
公式中的$V_{i,t}^{target}$代表累积回报,$s_{i}^{t}$是状态价值函数。
iii) 训练过程
基于MAPPO的多智能体优化伪代码如下算法所示,它与前面描述的MMOA-RAG总体框架相对应。
对于一个具体问题,第一步是执行收集轨迹(Collect Rollout)过程。它涉及依次经过查询改写器、检索器、选择器和生成器,计算得到的元组$T = (O_{QR}, subq, R_{QR}), (O_{S}, ID_{s}, R_{S}), (O_{G}, Ans_{predict}, R_{G})$会存储在经验回放缓冲区$M$中。
接下来,执行策略和价值优化(Policy and Value Optimization)过程,其中使用广义优势估计(GAE)来估计优势函数$\hat{A}_{i,t}^{\pi^{\theta}}$。
随后,计算总体损失函数$L(\theta, \phi)$,并更新策略模型和价值评估模型的参数。
此外,为了加速整个训练过程,可以并行运行小批量数据,得到一个训练良好的策略模型,用于后续的推理和评估。
实验
i) 与其他方法的比较
下表展示了不同方法在多个数据集上的性能,结果是以Llama-3–8B-Instruct为基础模型得到的。最后一行“Δ”显示了MMOA-RAG相对于最佳基线方法的提升。
MMOA-RAG在所有指标和数据集上都表现出色,凸显了其有效性。
MMOA-RAG可以看作是对普通RAG的增强,它集成了查询改写器和选择器,其作用类似于Rewrite-Retrieve-Read中的查询改写模块和BGM中的桥梁模块。
ii) 不同模块配置的RAG系统通用性实验
下表展示了不同模块配置的RAG系统通用性实验结果。符号“Δ”表示在MAPPO阶段相对于监督微调(SFT)阶段的提升。
“QR+S+G”代表MMOA-RAG,它展示了一个由三个智能体(查询改写器、选择器和生成器)组成的多智能体系统。
配置“S+G”是省略了查询改写器智能体,仅依靠初始问题$q$进行检索,从而将RAG系统配置为一个双智能体(选择器和生成器)系统。
相反,“QR+G”表示排除选择器智能体,形成一个由查询改写器和生成器两个智能体组成的RAG管道。
上表第二列明确指出,SFT指的是通过监督微调对相应RAG系统中的所有智能体进行热启动,而MAPPO指的是在SFT的基础上,利用MAPPO框架对所有智能体进行联合优化。
结果显示,在所有数据集上,使用联合MAPPO优化的RAG系统始终优于仅使用SFT的系统。
iii) 跨领域实验
下表展示了跨领域实验结果:模型在HotpotQA数据集上进行训练,在AmbigQA数据集上进行测试。
结果表明,MMOA-RAG在跨领域(OOD)实验中表现出色,突出了其显著的泛化能力。
结论
将RAG系统建模为一个多智能体协作任务,把查询改写器、选择器和生成器模块视为可学习的RL智能体。
采用多智能体强化学习算法对这些智能体进行联合优化,使多个模块的优化目标与生成高质量答案的最终目标保持一致。
实验证明了这种建模方法和联合优化方法的有效性。
本文由mdnice多平台发布
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。