Puzzle 8: Long Softmax
Puzzle 8 computes the softmax of a batch of logits. The problem statement is as follows:
Softmax of a batch of logits.
Uses one program block axis. Block size B0 represents the batch of x of length N0. Block logit length T. Process it B1 < T elements at a time.
.. math::
    z_{i, j} = \text{softmax}(x_{i,1} \ldots x_{i, T}) \text{ for } i = 1\ldots N_0
Note softmax needs to be computed in numerically stable form as in Python. In addition, in Triton they recommend not using exp but instead using exp2. You need the identity
.. math::
    \exp(x) = 2^{\log_2(e) x}
Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever.
Hint: you will find this identity useful:
.. math::
    \exp(x_i - m) = \exp(x_i - m/2 - m/2) = \exp(x_i - m/2) / \exp(m/2)
"""
def softmax_spec(x: Float32[4, 200]) -> Float32[4, 200]:
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    return x_exp / x_exp.sum(1, keepdim=True)
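As a quick sanity check of the exp2 identity the puzzle asks for, here is a small PyTorch snippet (illustrative only; 1.44269504 is log2(e)):

import torch

x = torch.randn(4, 200)
log2_e = 1.44269504  # log2(e)
# exp(x) == 2^(log2(e) * x), the identity used by the exp2-based helper below
torch.testing.assert_close(torch.exp(x), torch.exp2(log2_e * x))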
This puzzle calls for two solutions: a brute-force version with 3 loops and a clever version with 2 loops. Let's start with the brute-force one.
Brute-force approach
- One loop to take the maximum of each row.
- Subtract the row max from every element of the row, exponentiating along the way.
- One loop to sum the exponentiated values.
- One loop to compute the final softmax.
Relevant Triton API
- tl.full(shape, value, dtype) directly initializes a tensor of the given shape filled with value; handy for starting a running max from -inf. (It turns out tl.zeros also works here: the running max then becomes max(0, true max), which only shifts every logit in a row by the same constant, and softmax is shift-invariant; see the sketch below.)
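A minimal sketch of the two initializations inside a @triton.jit kernel body (names illustrative):

row_max = tl.full((B0, 1), float("-inf"), dtype=tl.float32)  # identity element for max
# tl.zeros also passes this puzzle: clamping the running max at 0 only shifts
# all logits of a row by a constant, and softmax is shift-invariant.
row_max_alt = tl.zeros((B0, 1), dtype=tl.float32)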
Solution
@triton.jit
def softmax_kernel_brute_force(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """3 loops ver. (exp_approx is the @triton.jit helper defined below.)"""
    block_id_i = tl.program_id(0)
    offset_x = block_id_i * B0 + tl.arange(0, B0)
    mask_x = offset_x < N0
    row_max = tl.zeros([B0, 1], dtype=tl.float32)
    row_sum_exp = tl.zeros([B0, 1], dtype=tl.float32)
    # Loop 1: running maximum of each row.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        row_max = tl.maximum(row_max, tl.max(block_value, axis=1, keep_dims=True))
    # Loop 2: sum of exp(x - row_max) for each row.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        row_sum_exp += tl.sum(exp_approx(block_value - row_max), axis=1, keep_dims=True)
    # Loop 3: normalize and store the softmax values.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        softmax_value = exp_approx(block_value - row_max) / row_sum_exp
        tl.store(z_ptr + offset_xy, softmax_value, mask_xy)
    return
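A minimal launch sketch (assuming a CUDA device; the grid and block sizes here are illustrative, not required by the puzzle):

import torch
import triton

x = torch.randn(4, 200, device="cuda")
z = torch.empty_like(x)
N0, T = x.shape
B0, B1 = 1, 32  # any B1 < T works; N1 is unused by this kernel
grid = (triton.cdiv(N0, B0),)
softmax_kernel_brute_force[grid](x, z, N0, 1, T, B0=B0, B1=B1)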
- The code is somewhat verbose, but the core idea is exactly the three loops described above.
Two-loop approach
- The idea is essentially online softmax:
$$
\begin{aligned}
& m_0 \leftarrow -\infty \\
& d_0 \leftarrow 0 \\
& \text{for } j \leftarrow 1, V \text{ do} \\
& \quad m_j \leftarrow \max(m_{j-1}, x_j) \\
& \quad d_j \leftarrow d_{j-1} \times e^{m_{j-1}-m_j} + e^{x_j-m_j} \quad \text{(update row\_sum\_exp within the loop)} \\
& \text{end for} \\
& \text{for } i \leftarrow 1, V \text{ do} \\
& \quad y_i \leftarrow \frac{e^{x_i-m_V}}{d_V} \\
& \text{end for}
\end{aligned}
$$
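The same recurrence in plain Python, as a non-Triton reference sketch over a single row:

import math

def online_softmax(xs):
    # Pass 1: running max m and exp-sum d; rescale d whenever m grows.
    m, d = float("-inf"), 0.0
    for x in xs:
        m_new = max(m, x)
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # Pass 2: normalize against the final m and d.
    return [math.exp(x - m) / d for x in xs]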
Solution
@triton.jit
def exp_approx(x):
    # exp(x) = 2^(log2(e) * x); Triton recommends exp2 over exp.
    return tl.exp2(1.44269504 * x)

@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """2 loops ver."""
    block_id_i = tl.program_id(0)
    offset_x = block_id_i * B0 + tl.arange(0, B0)
    mask_x = offset_x < N0
    row_max = tl.zeros([B0, 1], dtype=tl.float32)
    row_sum_exp = tl.zeros([B0, 1], dtype=tl.float32)
    # Loop 1: online update of the running max and the rescaled exp-sum.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        new_row_max = tl.maximum(row_max, tl.max(block_value, axis=1, keep_dims=True))
        # Rescale the old sum by exp(old_max - new_max), then add this block's terms.
        row_sum_exp = row_sum_exp * exp_approx(row_max - new_row_max) + tl.sum(
            exp_approx(block_value - new_row_max), axis=1, keep_dims=True
        )
        row_max = new_row_max
    # Loop 2: normalize and store.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        z_value = exp_approx(block_value - row_max) / row_sum_exp
        tl.store(z_ptr + offset_xy, z_value, mask_xy)
    return
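A quick correctness check against torch.softmax (assuming a CUDA device; block sizes and tolerances are illustrative):

x = torch.randn(4, 200, device="cuda")
z = torch.empty_like(x)
softmax_kernel[(4,)](x, z, 4, 1, 200, B0=1, B1=32)
torch.testing.assert_close(z, torch.softmax(x, dim=1), atol=1e-5, rtol=1e-5)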
Puzzle 9: Simple FlashAttention
A scalar version of FlashAttention.
Uses zero programs. Block size B0 represents the batches of q to process out of N0. Sequence length is T. Process it B1 < T elements (k, v) at a time for some B1.
.. math::
    z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0
This can be done in 1 loop using a similar trick from the last puzzle.
Hint: Use tl.where to mask q dot k to -inf to avoid overflow (NaN).
This is essentially FlashAttention v1: a single pass over the sequence.
The full recurrence of FlashAttention v1:
$$
\begin{aligned}
x_i &\leftarrow Q[k,:] \cdot K^T[:,i] \\
m_i &\leftarrow \max(m_{i-1}, x_i) \\
d_i' &\leftarrow d_{i-1}' \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} \\
O_i' &\leftarrow O_{i-1}' \cdot \frac{d_{i-1}'}{d_i'} \cdot e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} \cdot V[i,:]
\end{aligned}
$$
Final output:
$$ O[k,:] \leftarrow O_N' $$
where:
- $Q[k,:]$ is the $k$-th row vector of $Q$.
- $K^T[:,i]$ is the $i$-th column vector of $K^T$.
- $x_i$ is the pre-softmax logit.
- $m_i$ is the running maximum.
- $d_i'$ is the running sum of exponentials.
- $O_i'$ is the running partial output.
- $V[i,:]$ is the $i$-th row vector of $V$.
- $O[k,:]$ is the $k$-th row vector of the output matrix.
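Before the Triton kernel, here is the recurrence in plain Python for a single scalar query (an illustrative sketch; like the kernel below, it keeps the output un-normalized and divides by d once at the end rather than at every step):

import math

def flashatt_row(q_i, ks, vs):
    m, d, o = float("-inf"), 0.0, 0.0
    for k_j, v_j in zip(ks, vs):
        x = q_i * k_j                # pre-softmax logit
        m_new = max(m, x)
        alpha = math.exp(m - m_new)  # rescale old state when the max grows
        p = math.exp(x - m_new)
        d = d * alpha + p            # running exp-sum
        o = o * alpha + p * v_j      # running un-normalized output
        m = m_new
    return o / d                     # normalize once at the end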
Solution
@triton.jit
def myexp(x):
    return tl.exp2(1.44269504 * x)

@triton.jit
def flashatt_kernel(
    q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    off_i = block_id_i * B0 + tl.arange(0, B0)
    mask_i = off_i < N0
    # A large finite constant stands in for infinity to avoid NaN from inf - inf.
    inf = 1.0e6
    # `other` matters here: masked-out lanes must hold a defined value.
    q = tl.load(q_ptr + off_i, mask=mask_i, other=0.0)
    # The variable names of Triton's official FlashAttention tutorial are given
    # in comments for reference; our names are consistent with Puzzle 8.
    exp_sum = tl.zeros((B0,), dtype=tl.float32)      # l_i
    qk_max = tl.full((B0,), -inf, dtype=tl.float32)  # m_i
    z = tl.zeros((B0,), dtype=tl.float32)
    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]
        k = tl.load(k_ptr + off_j, mask=mask_j, other=0.0)
        # Mask invalid (i, j) logits to -inf so they drop out of the max and sum.
        qk = q[:, None] * k[None, :] + tl.where(mask_ij, 0.0, -inf)
        new_max = tl.maximum(tl.max(qk, axis=1), qk_max)         # m_ij
        qk_exp = myexp(qk - new_max[:, None])
        factor = myexp(qk_max - new_max)                         # alpha
        new_exp_sum = exp_sum * factor + tl.sum(qk_exp, axis=1)  # l_ij
        v = tl.load(v_ptr + off_j, mask=mask_j, other=0.0)
        # Keep z un-normalized; divide by the final exp_sum after the loop.
        z = z * factor + tl.sum(qk_exp * v[None, :], axis=1)
        qk_max = new_max
        exp_sum = new_exp_sum
    z = z / exp_sum
    tl.store(z_ptr + off_i, z, mask=mask_i)
    return
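A minimal check against a direct PyTorch evaluation of the spec (assuming a CUDA device; sizes and tolerances are illustrative):

T = 200
q, k, v = (torch.randn(T, device="cuda") for _ in range(3))
z = torch.empty(T, device="cuda")
flashatt_kernel[(triton.cdiv(T, 32),)](q, k, v, z, T, T, B0=32, B1=32)
expected = torch.softmax(q[:, None] * k[None, :], dim=1) @ v
torch.testing.assert_close(z, expected, atol=1e-4, rtol=1e-4)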
Reference
- Online softmax: https://zhuanlan.zhihu.com/p/20269643126
- https://zhuanlan.zhihu.com/p/1896616212974261259
- 【triton】triton调试工具与方法 (CSDN blog)
- Flash Attention: https://zhuanlan.zhihu.com/p/669926191