推理路径优化(RPO):提升大语言模型推理能力的新框架

📖阅读时长:19分钟

🕙发布时间:2025-02-12

近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企
公众号【柏企科技说】【柏企阅文

OpenAI o1等大语言模型(LLM)通过分步推理展现出了令人瞩目的问题解决能力。然而,在面对更为复杂的问题时,它们依旧会遭遇困境,所犯的错误甚至会扰乱推理进程。这一现象的根源在于解决方案空间过于庞大,每一步推理都潜藏着出错的风险,具体可参考下图。

为攻克这一难题,本文引入了名为“推理路径优化(RPO)”的专业训练框架。该框架能够让模型从不同路径展开推理与探索,进而提升大语言模型的推理能力。它在推理的每一步中,对有利的分支加以鼓励,对不利的分支予以惩罚,以此提升模型整体解决问题的性能。RPO无需依赖大规模人工标注的推理依据,也不依赖闭源模型的输出,具备良好的扩展性,数据利用效率也很高。

任务制定

主要聚焦于那些需要多步推理才能得出最终答案或结果的问题,例如数学应用题。

给定一个以自然语言文本形式提出的问题Q,目标是用自然语言文本生成最终答案A。

假设模型要历经S1、S2、... 、Sn等多个推理步骤才能得到最终答案A。推理路径P被定义为这些步骤的序列:P = (S1,S2,... ,Sn),其中每个Si都是一个自然语言句子,且最后一步Sn必须包含答案A。

RPO框架概述

推理路径优化(RPO)由三个主要阶段构成,具体如下图所示:

  1. 生成:此阶段的目标是从基础模型中获取正确的推理步骤,将其作为参考推理路径,这样一来,就无需获取事实性推理路径注释了。
  2. 勘探:为有效探索给定问题的潜在解决方案空间,该阶段会沿着参考推理路径,从每一步逐步创建分支。如此,便能得到多个有利和不利的推理分支,这些分支会为模型提供对比反馈。
  3. 优化:这是最后一个阶段,会依据参考推理路径以及探索出的分支进行聚合和优化,以此增强基础模型的先天推理能力。

详细方法

  1. 推理生成:该框架起始于推理生成阶段,通过思维链提示法自动生成参考推理路径。思维链演示输入DCoT由m个真实示例构成,每个示例都是由一个问题及其对应的推理路径组成的对子。设M为基础模型,通过使用思维链演示DCoT、给定的问题Q以及固定温度T来对模型进行提示,从而对参考推理路径P进行采样。如果生成的路径以正确答案结尾,就认定其为正确路径。因此,定义如下函数F来验证最后一步Sn ∈ P是否包含真实答案A:

如果输出不正确,即F(Pi) = 0,就重复采样和验证过程,直至F(Pi) = 1,不过尝试次数上限为10次,即i ≤ 10。要是多次尝试后都未能获得合适的路径,那就判定该问题超出了模型的能力范围,会将其从训练集中剔除。

  1. 推理探索:给定问题Q、思维链演示,以及生成的推理路径P1:i−1 = (S1,S2,... ,Si−1)的前面步骤,利用温度采样法从推理路径的当前点获取不同分支。每个分支Bi = (S′i,S′i+1,... ,S′l)都应涵盖从当前步骤到最后一步的内容。目标是得到一个有利分支B+i和一个不利分支B−i,其中有利分支能导向正确的最终答案,而不利分支则不能。为实现这一目标,从每一步S′i开始迭代采样多个分支,并用函数F对每个分支进行验证,直至获得一个有利分支和一个不利分支,进而形成一个推理分支对(B+i ,B−i )。然而,如果在最多采样10个分支后仍无法形成分支对,就会将该问题从训练集中移除。
  2. 推理优化:为优化基础模型M,会同时考虑生成的参考推理路径P以及推理分支对(B+i ,B−i )。对参考推理路径P应用标准因果语言建模损失,以输入问题Q为条件

。在该框架中,为表述分支对损失,可以利用现有研究中基于偏好的目标,比如直接偏好目标或优势比目标。在这项研究中,第i步的分支对损失Lbp,i可计算为在输入问题Q和参考路径P的条件下,有利分支B+i和不利分支B−i之间的对数优势比(此处若有公式,详细列出公式内容)。生成分支的优势比可计算为在输入问题Q和参考路径前序步骤P1:i−1的条件下,生成该分支的概率与不生成它的概率之比

将推理路径中每个步骤对应的先前探索分支对的损失进行汇总

其中,推理路径有n个步骤。最后,该框架中的总体损失LRPO表示为参考路径损失Lref和探索损失Lexp的组合,探索损失Lexp会对探索出的分支对提供对比反馈

其中,λ是一个超参数权重,直观上用于平衡参考推理路径优化和探索分支优化之间的关系。

实验

  1. 实施:为评估RPO方法,选用Mistral - 7B和LLaMA - 3 - 8B作为基础模型,它们分别是Mistral和LLaMA模型家族中较新且受欢迎的基础大语言模型。在模型训练时,采用LoRA微调方法,批量大小固定为8,学习率设为5e - 5。为从模型中采样多个输出,使用固定的采样温度0.5。在评估时,采用贪心解码进行生成,并使用准确率指标进行评分。
  2. 对比方法:与包括特定推理训练方法和基于偏好的优化方法在内的强大基线模型进行对比:

    • 监督微调(SFT):考虑不使用任何推理路径进行训练的情况,仅训练模型直接生成真实的最终答案。
    • 拒绝采样微调(RFT):将RFT作为监督训练的强大基线,它利用模型自行生成推理路径用于训练。这种方法与本框架中的推理生成阶段类似,旨在克服缺乏真实推理路径的数据限制。
    • 直接偏好优化(DPO):由于我们的方法会对比有利和不利的推理分支,其动机与为模型提供对比反馈的DPO相似。
    • 优势比偏好优化(ORPO):我们的方法与ORPO的主要区别在于,推理路径优化是一个专门为基于推理的任务设计的整体框架。我们认为推理错误可能在推理路径的任何一步发生,因此探索可能的解决方案路径对于提供不同推理分支对的对比反馈是必要的。
  3. 结果:下图展示了不同训练方法在GSM8K和MATH中的数学推理问题,以及MMLU - STEM中的科学类考试问题上的评估准确率(%)。结果显示,RPO方法在不同数据集和模型上的性能都有持续提升。尤其在以Mistral - 7B为基础进行训练时,与表现最佳的基线模型相比,我们的方法在GSM8K和MMLU - STEM上分别可实现高达3.1%和4.3%的性能提升。同时还发现,与其他基于自我探索推理路径训练的方法相比,SFT的性能较低。这表明,虽然模型有可能在不经过任何推理步骤的情况下直接生成答案,但对于更复杂的推理问题,这种方式的效果较差。

  1. 探索权重的影响:λ值较低时,会更侧重于参考路径上导向正确答案的监督损失。另一方面,λ值较高时,训练过程中会更重视探索分支,这些分支对比了每个推理步骤产生的有利和不利分支。下图展示了探索损失权重对LLaMA - 3 - 8B在MATH数据集上性能的影响,这表明λ值过低会导致结果不理想,因为它没有对推理探索给予足够的重视。

  1. 推理路径长度分析:下图展示了LLaMA - 3 - 8B在MATH数据集上的性能与推理路径长度的关系。结果表明,我们的方法在解决需要更复杂推理的问题时,能够有效减少错误的发生。

  1. 基于代码的推理:下表展示了以基于文本的推理作为主要设置,以及通过Python程序进行基于代码的推理的性能优势分析。实验使用LLaMA - 3 - 8B进行。结果表明,与最强的基线模型ORPO相比,基于文本的推理和基于代码的推理都能从我们的方法中获得类似的性能提升。

  1. 对比目标的影响:下表展示了使用LLaMA - 3 - 8B在我们框架中不同对比目标下,在GSM8K、MATH和MMLU - STEM数据集上的性能比较。结果显示,不同目标下都有一致的性能提升,这表明我们的框架具有很强的鲁棒性,优于各个基线模型。

  1. 参考路径的影响:下表展示了使用LLaMA - 3 - 8B在GSM8K上不同参考路径的性能比较。结果表明,即使参考路径有所变化,我们的方法依然有效,证明了它在不同参考路径下的鲁棒性。

局限性

RPO框架依赖模型在训练阶段生成正确推理路径的能力。如果基础模型表现不佳,可能难以生成必要的正确路径,从而限制了该方法的有效性。虽然为每个问题生成和探索多个推理路径的过程计算成本较高,但这只是训练阶段的一次性成本。因此,我们认为这是为提升性能而做出的值得的权衡,并且这种成本可以在多次推理过程中分摊。

结论

本文介绍了一种名为推理路径优化(RPO)的全新训练框架,用于提升大语言模型的分步推理能力。该方法应对了复杂问题解决任务中的挑战,在这类任务中,每一步推理都有出错的风险。RPO考虑了不同的推理分支对,在每个推理步骤中鼓励有利分支,同时惩罚不利分支。我们的框架具有可扩展性,因为它不依赖大规模人工标注的推理依据,而是利用模型自身生成的推理路径,能够很好地适应诸如数学应用题这类多步推理任务。

论文链接:https://arxiv.org/abs/2410.10858
代码链接:https://github.com/DAMO-NLP-SG/reasoning-paths-optimization

参考文献 Reasoning Paths Optimization: Learning to Reason and Explore From Diverse Paths by Chia等. arXiv:2410.10858

本文由mdnice多平台发布


柏企科技圈
23 声望5 粉丝