光的语言模型推理速度

主要观点:在开发基于 Transformer 的语言模型推理的最小从头开始的快速 CUDA 实现[calm]过程中,考虑了推理过程的光速并测量其进展。文中涵盖了理论极限及其影响,包括推理机制、Mistral 的光速、理论界限的实用性、结论和附录等方面。

关键信息:

  • 语言模型生成 token 时一次一个,生成过程可按 token 建模,有矩阵向量乘法和注意力计算两种操作,且均为每个元素进行少量浮点运算,现代 CPU 和 GPU 运算能力远超内存读取能力,此类问题受带宽限制。
  • 以 Mistral 7B 为例,计算了其矩阵乘法所需读取的数据量及注意力计算所需读取的 KV-cache 数据量,得出推理的最小时间下限,如在 NVIDIA RTX 4090 上 14.2GB 数据读取约需 14.1ms/令牌,使用 8 位权重约需 7.0ms。
  • 理论界限有用处,实际达到该时间需高质量软件实现和能达到理论峰值带宽的硬件,可通过比较实际性能与理论极限评估实现质量,模型推理过程未充分利用 ALU 单元,批处理对 KV-cache 带宽帮助不大,带宽可作为硬件选择的关键估计器。
  • Mistral-7B 平衡良好,KV-cache 不是成本结构的关键部分原因之一是使用了组查询注意力(GQA),可减少 KV-cache 的大小和所需带宽,如 Command-R 模型不使用 GQA,在长上下文时 KV-cache 所需内存巨大,使用 GQA 则可减小。

重要细节:

  • AMD Ryzen 7950X 内存带宽 67GB/s,浮点运算 2735GFLOPS,FLOP:byte 比为 40:1;NVIDIA GeForce RTX 4090 内存带宽 1008GB/s,浮点运算 83TFLOPS,FLOP:byte 比为 82:1;NVIDIA H100 SXM 内存带宽 3350GB/s,浮点运算 67TFLOPS,对于矩阵乘法问题张量核心提供约 494TFLOPS 无稀疏性,FLOP:byte 比为 147:1,更小浮点格式下此比例更差。
  • Mistral 7B 各参数组成及计算,如嵌入矩阵 409632000=131M 参数,用于计算注意力相关向量 32(4096(12832+12882)+409612832)=1342M 参数等,总计约 7111M“活跃”参数,使用 FP16 时每次读取约 14.2GB 数据。
  • 在不同硬件上的实验结果,如在 NVIDIA RTX 4090 上,calm使用 16 位权重时 Mistral 7B 推理约 15.4ms/令牌,使用 8 位权重约 7.8ms/令牌;在 Apple M2 Air 上使用 CPU 推理时,calmllama.cpp仅达到理论 100GB/s 带宽的约 65%。
  • 附录中提到 Group query attention 可减少 KV-cache 大小和所需带宽,以 Mistral-7B 为例短上下文时优势不明显,但对于长上下文或多用户场景很重要,如不使用 GQA 的 Cohere 的 Command-R 模型在长上下文时 KV-cache 所需内存巨大,使用 GQA 则可减小。
阅读 78
0 条评论