MMOA-RAG 用多智能体强化学习优化检索增强生成模型,效果提升显著!

原文链接:Improving Retrieval-Augmented Generation through Multi-Agent Reinforcement Learning
代码链接:https://github.com/chenyiqun/MMOA-RAG

📖阅读时长:25分钟

🕙发布时间:2025-02-06

近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企
公众号【柏企科技说】【柏企阅文

检索增强生成(RAG)管道组件通常通过监督微调单独优化,这可能会导致各个模块的目标与在问答(QA)任务中生成准确答案的总体目标不一致。

为了解决这个问题,本文提出将RAG管道视为一个多智能体合作任务,把每个组件都看作是一个强化学习(RL)智能体,推出了MMOA-RAG,即一种用于RAG的多模块联合优化算法。它采用多智能体强化学习的方式,让所有智能体朝着统一的奖励目标(比如最终答案的F1分数)努力。

前期准备

i) 把RAG建模为协同多智能体强化学习(Co-MARL)

本文将RAG过程放在协同多智能体强化学习(Co-MARL)框架中进行概念化理解。

在这个框架里,RAG管道的每个模块都作为一个独立的RL智能体。这个多智能体系统的总体目标是生成高质量答案,这和每个模块的个体目标是一致的。

定义元组$\langle G, O, A, R\rangle$,其中$G$表示Co-MARL系统中的智能体集合,$O$代表每个智能体可获取的观察信息,$A$构成每个智能体可触及的行动空间,$R$是所有智能体共享的奖励 。

ii) MAPPO算法

本文使用多智能体近端策略优化算法(Multi-Agent PPO,MAPPO),它是近端策略优化算法(PPO)在多智能体环境下的扩展,用于优化Co-MARL框架中每个智能体的策略。

与专注于单智能体场景、使用个体奖励的PPO不同,MAPPO采用共享的全局奖励来促进所有智能体之间的合作。

在PPO中,价值评估模型(critic model)受限于智能体的观察信息,而MAPPO的全局价值评估模型可以获取全面的全局信息,这使得全局价值评估模型能够更准确地估计状态价值函数。

方法概述

下图展示了MMOA-RAG框架的架构,它主要由四个模块组成:查询改写器(Query Rewriter)、检索器(Retriever)、选择器(Selector)和生成器(Generator):

i) 查询改写器

它会重新组织初始查询$q$。初始查询可能过于复杂或模糊,无法通过一次检索解决,查询改写器会将其转化为一组子问题,记为$subq$。

ii) 检索器

检索器会针对每个子问题,分别从语料库中检索相关文档,并输出一组候选文档$D$。

iii) 选择器

选择器会对$D$进一步筛选,得到文档子集$D_{selected}$,这些文档对生成初始查询$q$的最终答案很有用。

iv) 生成器

生成器利用$D_{selected}$,生成对初始问题的预测答案$Ans_{predict}$。

MMOA-RAG框架的重点在于多个模块的协同优化,使它们各自的优化目标与生成高质量答案的最终目标保持一致。

每个智能体的详细配置

i) 查询改写器的要素
  • 观察值:定义如下公式,它包含查询改写器的提示$Prompt_{QR}$和初始问题$q$ 。
  • 行动空间:与大语言模型(LLMs)的词汇表$V$相对应。
  • 奖励函数:定义如下公式。$R_{shared}$可以是衡量最终答案的指标,$P_{QR}$项作为惩罚项,用于防止查询改写器在训练过程中生成过多子问题。如果子问题数量超过四个,$P_{QR}$赋值为 -0.5;若子问题数量为四个或更少,则设置为0。
ii) 选择器的要素
  • 观察值:定义如下公式,它包含选择器的提示$Prompt_{S}$、初始问题$q$以及包含$K$个文档的候选文档集$D$。

  • 行动空间:由于选择器的功能是输出有助于回答初始问题$q$的候选文档ID,所以行动空间被限制在这个有限的词汇集合中,定义如下。

  • 奖励函数:包含两项,即$R_{shared}$和$P_{S}$ 。$P_{S}$是惩罚项,用于防止选择器生成重复的文档ID,以及输出不符合指定格式(例如Document0,Document3,Document9)的ID。当选择器输出重复文档ID或未遵循指定格式时,$P_{S}$设置为 -1;否则,$P_{S}$设置为0。

iii) 生成器的要素
  • 观察值:定义如下公式,包含生成器的提示$Prompt_{G}$、初始问题$q$以及选择器给出的选定候选文档集$D_{selected}$。

  • 行动空间:定义如下公式,与查询改写器的行动空间相同。

  • 奖励函数:包含$R_{shared}$和惩罚项$P_{G}$,$P_{G}$用于限制模型生成过长的内容。当生成的答案超过一定长度时,$P_{G}$设置为 -0.5;否则,设置为0。

使用监督微调(SFT)进行热启动

热启动能让模型在不同任务中更好地遵循指令,减少多智能体强化学习(MARL)联合训练时的探索空间,从而提高探索和利用的效率。

i) 查询改写器

使用近端策略优化算法(PPO)训练一个小型语言模型,以便有效地为RAG改写查询。在此基础上,利用来自Rewrite-Retrieve-Read的公开查询改写数据作为监督微调(SFT)数据集,对MMOA-RAG中的查询改写器进行热启动。

ii) 选择器

选择器的任务是从给定的包含$K$个候选文档的集合$D$中,选择出有助于回答问题的子集$D_{selected}$。选择器的输出格式是$D_{selected}$中文档的ID(例如Document0, Document4, Document6, Document7),如下表中选择器的提示所示:

角色内容
系统你是一个乐于助人、尊重他人且诚实的助手。你的任务是输出有助于回答问题的候选文档ID(0,1,2,...,K - 1)。
助手好的,我会提供有助于回答问题的候选文档ID。
用户问题:{问题内容}
用户Document0:{Document0的内容}
......
用户Document(K - 1):{Document(K - 1)的内容}
助手好的,我收到了问题和候选文档。
用户现在,输出有助于回答问题“{问题内容}”的候选文档ID(0,1,2,...,K - 1),例如按照以下格式:Document0,Document4,Document6,Document7。
  • 为选择器构建SFT数据:提出了一种简便的启发式方法来构建SFT数据,旨在让大语言模型有效地遵循指令并按期望的形式输出,如下图所示并进一步解释:

对于给定的问题$q_{i}$及其标准答案,有$K$个候选文档记为$d_{i,j}$,其中$j \in \{0, 1, · · · ,K - 1\}$。首先,从$q_{i}$及其标准答案中去除某些无意义的停用词和标点符号,并将单词转换为小写,得到集合$Set_{q_{i}}$。同样地,对$K$个候选文档$d_{i,j}$执行相同操作,得到$Set_{d_{i,j}}$。最后,如果$Set_{q_{i}}$中的任何单词出现在$Set_{d_{i,j}}$中,那么相应文档$j$的ID就会被包含在最终输出中,作为SFT的标签。

iii) 生成器

生成器负责根据选择器提供的$D_{selected}$生成最终答案$Ans_{predict}$。生成器SFT数据的真实值是标准答案$Ans_{golden}$。

多智能体优化

采用与《星际争霸II》中多智能体近端策略优化类似的设置,多个智能体共享一个全局奖励,即使用$R_{shared}$来优化$G$ = {查询改写器(QR)、选择器(S)、生成器(G)}。

为了降低计算开销,在智能体之间应用参数共享机制,让QR、S和G能够使用相同的大语言模型。

i) 模型

在多智能体优化过程中,需要考虑三个模型:

  • 策略模型(Actor model):参数记为$\theta$。其作用是基于每个智能体$i$的观察值$O_{i}$提供响应$Answer_{i}$。
  • 价值评估模型(Critic model):参数记为$\phi$。负责估计状态价值函数$V_{i,t}^{\phi}$,这是强化学习算法中策略 - 价值(Actor-Critic)架构的经典设置。
  • 监督微调模型(SFT model):参数记为$\theta_{SFT}$。作为策略模型的基线,类似于InstructGPT。
ii) 损失函数

目标是更新策略模型和价值评估模型的参数。

总体损失函数$L(\theta, \phi)$由两项组成:$L_{Actor}(\theta)$和$L_{Critic}(\phi)$。

策略模型的损失函数如下公式(12)所示,与典型的单智能体近端策略优化算法中使用的损失函数类似,主要区别在于这里是对多个智能体进行优化。

在公式(12)中,$i \in G$表示三个智能体:查询改写器、选择器和生成器。公式(13)中的$r_{i}^{t}$表示重要性采样比,用于衡量新旧策略之间的差异。

最终奖励函数$R(s_{i}^{t}, a_{i}^{t})$定义如下公式(16)。

价值评估模型的损失函数,如公式(17)所示,采用了与策略模型类似的裁剪操作。

公式中的$V_{i,t}^{target}$代表累积回报,$s_{i}^{t}$是状态价值函数。

iii) 训练过程

基于MAPPO的多智能体优化伪代码如下算法所示,它与前面描述的MMOA-RAG总体框架相对应。

对于一个具体问题,第一步是执行收集轨迹(Collect Rollout)过程。它涉及依次经过查询改写器、检索器、选择器和生成器,计算得到的元组$T = (O_{QR}, subq, R_{QR}), (O_{S}, ID_{s}, R_{S}), (O_{G}, Ans_{predict}, R_{G})$会存储在经验回放缓冲区$M$中。

接下来,执行策略和价值优化(Policy and Value Optimization)过程,其中使用广义优势估计(GAE)来估计优势函数$\hat{A}_{i,t}^{\pi^{\theta}}$。

随后,计算总体损失函数$L(\theta, \phi)$,并更新策略模型和价值评估模型的参数。

此外,为了加速整个训练过程,可以并行运行小批量数据,得到一个训练良好的策略模型,用于后续的推理和评估。

实验

i) 与其他方法的比较

下表展示了不同方法在多个数据集上的性能,结果是以Llama-3–8B-Instruct为基础模型得到的。最后一行“Δ”显示了MMOA-RAG相对于最佳基线方法的提升。

MMOA-RAG在所有指标和数据集上都表现出色,凸显了其有效性。

MMOA-RAG可以看作是对普通RAG的增强,它集成了查询改写器和选择器,其作用类似于Rewrite-Retrieve-Read中的查询改写模块和BGM中的桥梁模块。

ii) 不同模块配置的RAG系统通用性实验

下表展示了不同模块配置的RAG系统通用性实验结果。符号“Δ”表示在MAPPO阶段相对于监督微调(SFT)阶段的提升。

“QR+S+G”代表MMOA-RAG,它展示了一个由三个智能体(查询改写器、选择器和生成器)组成的多智能体系统。

配置“S+G”是省略了查询改写器智能体,仅依靠初始问题$q$进行检索,从而将RAG系统配置为一个双智能体(选择器和生成器)系统。

相反,“QR+G”表示排除选择器智能体,形成一个由查询改写器和生成器两个智能体组成的RAG管道。

上表第二列明确指出,SFT指的是通过监督微调对相应RAG系统中的所有智能体进行热启动,而MAPPO指的是在SFT的基础上,利用MAPPO框架对所有智能体进行联合优化。

结果显示,在所有数据集上,使用联合MAPPO优化的RAG系统始终优于仅使用SFT的系统。

iii) 跨领域实验

下表展示了跨领域实验结果:模型在HotpotQA数据集上进行训练,在AmbigQA数据集上进行测试。

结果表明,MMOA-RAG在跨领域(OOD)实验中表现出色,突出了其显著的泛化能力。

结论

将RAG系统建模为一个多智能体协作任务,把查询改写器、选择器和生成器模块视为可学习的RL智能体。

采用多智能体强化学习算法对这些智能体进行联合优化,使多个模块的优化目标与生成高质量答案的最终目标保持一致。

实验证明了这种建模方法和联合优化方法的有效性。

本文由mdnice多平台发布


柏企科技圈
1 声望0 粉丝

时间差不多了,快上车!~