主要观点:在生产中使用大型语言模型(LLM)进行推理并不容易,需要大量计算且成本较高,因此出现了多种提高推理效率的技术,如推测解码(speculative decoding)。
关键信息:
- 传统 LLM 推理效率低,一次前向传播仅生成一个令牌,大量计算资源未被充分利用。
- 推测解码受传统处理器推测执行概念启发,通过使用较小的推测模型在一次前向传播中预测多个令牌来提高 GPU 资源利用率和推理吞吐量。
- 原始推测解码论文提出使用较小的推测模型预测下几个令牌,过程包括推测生成令牌、用基础模型验证以及并行推理生成更多令牌等。
- Medusa 架构通过在基础 LLM 顶部添加多个预测头来进行推测解码,无需单独的推测模型,预测头需微调以匹配基础模型,还提出了两种微调方法。
- Medusa 架构在推测解码时会使用树注意力来避免处理无效延续,采用不同的令牌接受方案,以接受更典型的候选令牌。
- 研究发现 Medusa 架构有一定局限性,如多个预测头接收相同输入导致预测偏离基础模型行为,研究者提出用多阶段 MLP 替换预测头并采用多候选解码等修改方案。
重要细节: - 原始推测解码实验中使用 LaMDA-137B 作为基础模型,LaMDA-100M 作为推测模型在对话任务中进行实验。
- Medusa 架构中每个预测头在给定最后隐藏层输出作为输入时预测下一个令牌,每个头产生多个预测并选择 top-k 。
- 微调 Medusa 预测头的两种方法,一种是冻结基础模型层最小化交叉熵损失,另一种是联合训练基础模型和预测头。
- 验证延续时使用基于基础模型预测概率和阈值的方案,始终贪婪接受第一个令牌。
- Medusa 架构在不同规模的 Vicuna 模型上进行实验,发现速度提升,但其在一些情况下效率增益会随批量大小增加而消失。
- 在 Codellama-13B-instruct 模型中增加预测头数量实现更大的加速。
- 推测解码技术有广阔的改进空间,随着 LLM 的广泛应用,推理优化是热门领域值得关注。
- 作者提供了支持方式,包括付费订阅、Buy me a coffee 和 GitHub Sponsor 。
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。