Chinchilla,如何打破大语言模型训练 “魔咒”?
🕙发布时间:2025-02-25
更多LLM架构文章:LLM架构专栏
近日热文:
1. 全网最全的神经网络数学原理(代码和公式)直观解释
2. 大模型进化史:从Transformer到DeepSeek-R1的AI变革之路
3. 2W8000字深度剖析25种RAG变体:全网最全~没有之一
4. 3W6000字了解大模型LLM:部署、优化与框架
知乎【柏企】
公众号【柏企科技说】【柏企阅文】
本文研究了在给定计算预算的情况下,训练 Transformer 架构的大语言模型(LLM)的最优模型规模和训练 token 数量。研究发现,由于当前在保持训练数据量不变的同时过度强调扩大模型规模,现有的大语言模型并没有得到充分训练。
通过在 50 亿到 5000 亿个 token 上训练 400 多个参数规模从 7000 万到超过 160 亿的语言模型,我们发现,为了实现计算最优的训练,模型规模和训练 token 的数量应该同等缩放:模型规模每扩大一倍,训练 token 的数量也应该翻倍。
作者通过训练一个预测的计算最优模型 Chinchilla(龙猫)来验证这一假设。Chinchilla 与 Gopher 使用相同的计算预算,但它有 700 亿个参数,并且训练数据量是 Gopher 的 4 倍。在一系列下游评估任务中,Chinchilla 普遍且显著地优于 Gopher(2800 亿参数)、GPT-3(1750 亿参数)、Jurassic-1(1780 亿参数)和 Megatron-Turing NLG(5300 亿参数)。
这也意味着 Chinchilla 在微调与推理时所需的计算量大幅减少,极大地方便了下游应用。值得一提的是,Chinchilla 在 MMLU 基准测试中达到了 67.5%的最先进平均准确率,比 Gopher 提高了超过 7%。
研究表明,当前的大模型规模应该大幅缩小,并且训练时间要比现在长得多。
估计最优参数/训练 token 数量
本文提出了三种不同的方法来回答推动这项研究的问题:在给定的浮点运算次数(FLOPs)预算下,应该如何在模型规模和训练 token 数量之间进行权衡?
在这三种方法中,他们首先训练一系列模型,同时改变模型规模和训练 token 的数量,然后利用得到的训练曲线来拟合一个经验估计器,以确定它们应该如何缩放。虽然未来的研究可能会考虑在大规模模型中这种关系的潜在曲线,但这里假设计算量和模型规模之间存在幂律关系。
在第一种方法中,他们针对固定的一系列模型(参数规模从 7000 万到超过 100 亿)改变训练步数,对每个模型使用 4 种不同数量的训练序列进行训练。从这些实验中,他们能够直接得出在给定训练浮点运算次数下可达到的最小损失的估计值。
在第二种方法中,他们针对 9 种不同的固定训练浮点运算次数(从 6×10¹⁸ 到 3×10²¹ 次浮点运算)改变模型规模,并考虑每个点的最终训练损失。这使他们能够直接回答这个问题:对于给定的浮点运算次数预算,最优的参数数量是多少?
最后,他们将方法 1 和方法 2 实验中的所有最终损失建模为模型参数数量和已见 token 数量的参数化函数。
作者发现,尽管这三种方法使用了不同的拟合方法和不同的训练模型,但对于参数、token 数量和浮点运算次数的最优缩放,它们给出了相当的预测结果。
各种模型规模的估计最优训练浮点运算次数和训练 token 数量
模型
作者使用 MassiveText(与 Gopher 相同的数据集)训练 Chinchilla,但为了适应增加的训练 token 数量,使用了略有不同的子集分布。
他们在训练 Chinchilla 时使用了 AdamW 优化器,而不是 Adam,因为这可以降低语言建模损失,并提高微调后的下游任务性能。
Chinchilla 使用了略微修改的 SentencePiece 分词器,该分词器不应用 NFKC 规范化。词汇表非常相似,94.15%的 token 与训练 Gopher 时使用的 token 相同。研究结果表明,这对数学和化学相关内容的表示特别有帮助。
虽然前向和反向传播是用 bfloat16 计算的,但我们在分布式优化器状态中存储了权重的 float32 副本。
Chinchilla 架构细节
结果
- 所有评估任务:我们在一系列语言建模任务以及下游任务上对 Chinchilla 进行评估。
- 语言建模
- The Pile 评估:对于 The Pile 中的不同评估集,我们展示了 Chinchilla 相对于 Gopher 在每字节比特数(bpb)上的改进(降低)情况。在所有子集中,Chinchilla 的表现都优于 Gopher。
- MMLU
- 大规模多任务语言理解(MMLU):我们报告了在 57 个任务上的平均 5-shot 准确率,并与模型和人类的准确率进行了比较。我们还纳入了 73 位有竞争力的人类预测者对 2022 年 6 月/2023 年最先进准确率的平均预测。 - 与 Gopher 相比的 MMLU 结果:我们发现,Chinchilla 的平均准确率比 Gopher 高 7.6%,在 57 个单独任务中,Chinchilla 在 51 个任务上表现更好,2 个任务表现相同,只有 4 个任务表现更差。
- 阅读理解:在 RACE-h 和 RACE-m 数据集上,Chinchilla 的性能比 Gopher 有显著提升。需要注意的是,GPT-3 和 MT-NLG 530B 在 RACE-h/m 上使用的提示格式与我们不同,所以其结果与 Gopher 和 Chinchilla 不具有可比性。在 LAMBADA 数据集上,Chinchilla 的表现优于 Gopher 和 MT-NLG 530B。
- Big Bench - 与 Gopher 相比的 BIG-bench 结果:在几乎所有考虑的 BIG-bench 任务中,Chinchilla 的表现都优于 Gopher。
常识
- 常识基准测试中的零样本比较:我们展示了 Chinchilla、Gopher 和 MT-NLG 530B 在各种常识基准测试中的比较。我们发现,Chinchilla 在所有任务上的表现与 Gopher 和 GPT-3 相当或更优。除了一项任务外,Chinchilla 在其他所有任务上的表现都优于参数多得多的 MT-NLG 530B 模型。
- 闭卷问答:在自然问题(Natural Questions)和琐事问答(TriviaQA)任务中,Chinchilla 在所有情况下的表现都优于 Gopher。在自然问题任务中,Chinchilla 的表现优于 GPT-3。在琐事问答任务中,我们展示了在两个不同评估集上的结果,以便与 GPT-3 以及开卷问答的最先进模型进行比较。
论文
Training Compute-Optimal Large Language Models 2203.15556
## 推荐阅读
1. DeepSeek-R1的顿悟时刻是如何出现的? 背后的数学原理
2. 微调 DeepSeek LLM:使用监督微调(SFT)与 Hugging Face 数据
3. 使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT
4. DeepSeek R1:了解GRPO和多阶段训练
5. 深度探索:DeepSeek-R1 如何从零开始训练
6. DeepSeek 发布 Janus Pro 7B 多模态模型,免费又强大!
本文由mdnice多平台发布
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。