Puzzle 8: Long Softmax

Puzzle 8 computes the softmax over a batch of rows. The problem statement is as follows:

Softmax of a batch of logits.

Uses one program block axis. Block size B0 represents the batch of x of length N0.

Block logit length T.   Process it B1 < T elements at a time.  

.. math::

    z_{i, j} = \text{softmax}(x_{i,1} \ldots x_{i, T}) \text{ for } i = 1\ldots N_0

Note softmax needs to be computed in numerically stable form as in Python. In addition, in Triton they recommend not using exp but instead using exp2. You need the identity

.. math::

    \exp(x) = 2^{\log_2(e) x}

Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever.

Hint: you will find this identity useful:

.. math::

    \exp(x_i - m) =  \exp(x_i - m/2 - m/2) = \exp(x_i - m/ 2) /  \exp(m/2)

"""

def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    return x_exp / x_exp.sum(1, keepdim=True)
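
A quick sanity check of the exp2 identity and of the spec above (a small sketch, not part of the puzzle; it assumes softmax_spec is in scope):

import torch

x = torch.randn(4, 200)
LOG2_E = 1.4426950408889634  # log2(e)
# exp(x) = 2^(log2(e) * x)
assert torch.allclose(torch.exp(x), torch.exp2(LOG2_E * x))
# The reference spec matches torch.softmax.
assert torch.allclose(softmax_spec(x), torch.softmax(x, dim=1), atol=1e-6)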

The puzzle asks for two solutions: a brute-force one with three loops and a cleverer one with two loops. Let's start with the brute-force version.

Brute-force approach

  1. One loop to find the maximum of each row.
  2. Subtract the row maximum from every element of the row and exponentiate along the way.
  3. One loop to accumulate the sum of the exponentiated values (this shares a loop with step 2).
  4. One loop to compute the final softmax and store it (a blocked PyTorch sketch follows this list).
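
Here is that three-pass structure as a blocked PyTorch sketch, for intuition only; the block size B1 and the 4x200 shape are illustrative choices, not part of the puzzle:

import torch

def softmax_three_pass(x, B1=32):
    N0, T = x.shape
    row_max = torch.full((N0, 1), float("-inf"))
    row_sum_exp = torch.zeros(N0, 1)
    z = torch.empty_like(x)
    # Pass 1: running row maximum, B1 columns at a time.
    for j in range(0, T, B1):
        row_max = torch.maximum(row_max, x[:, j:j + B1].max(1, keepdim=True)[0])
    # Pass 2: sum of exp(x - row_max).
    for j in range(0, T, B1):
        row_sum_exp += (x[:, j:j + B1] - row_max).exp().sum(1, keepdim=True)
    # Pass 3: normalize and write out.
    for j in range(0, T, B1):
        z[:, j:j + B1] = (x[:, j:j + B1] - row_max).exp() / row_sum_exp
    return z

x = torch.randn(4, 200)
assert torch.allclose(softmax_three_pass(x), torch.softmax(x, dim=1), atol=1e-6)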

Relevant Triton APIs

  1. tl.full(shape, value, dtype) creates a tensor of the given shape and dtype filled with value, which is handy for initializing the running maximum to a very negative value; it turns out tl.zeros also works here (see the check below).
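
Why tl.zeros also works: softmax of a row is unchanged if you subtract any constant from that row, so using max(0, row_max) instead of the true row max is still mathematically exact; it only gives less underflow protection when every entry is strongly negative. A quick PyTorch check:

import torch

x = torch.randn(4, 200)
# Subtract max(0, row_max) instead of the exact row max.
m = x.max(1, keepdim=True)[0].clamp(min=0)
z = (x - m).exp()
z = z / z.sum(1, keepdim=True)
assert torch.allclose(z, torch.softmax(x, dim=1), atol=1e-6)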

Solution

import triton
import triton.language as tl


@triton.jit
def softmax_kernel_brute_force(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """3 loops ver. Uses the exp_approx helper (tl.exp2-based exp) defined with the 2-loop solution below."""
    block_id_i = tl.program_id(0)
    offset_x = block_id_i * B0 + tl.arange(0, B0)
    mask_x = offset_x < N0

    # Running per-row statistics. Starting the max at 0 works because softmax is
    # shift-invariant (see the check above); tl.full([B0, 1], float('-inf'), tl.float32)
    # is the conventional choice.
    row_max = tl.zeros([B0, 1], dtype=tl.float32)
    row_sum_exp = tl.zeros([B0, 1], dtype=tl.float32)

    # Loop 1: row maximum, B1 columns at a time.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        row_max = tl.maximum(row_max, tl.max(block_value, axis=1, keep_dims=True))

    # Loop 2: sum of exp(x - row_max).
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        row_sum_exp += tl.sum(exp_approx(block_value - row_max), axis=1, keep_dims=True)

    # Loop 3: normalize and store.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        softmax_value = exp_approx(block_value - row_max) / row_sum_exp
        tl.store(z_ptr + offset_xy, softmax_value, mask_xy)
    return
  1. The code is fairly verbose, but the core idea is exactly the three loops described above.

Two-loop approach

  1. The idea is essentially online softmax: maintain the running maximum and the running sum of exponentials together in a single pass, then normalize in a second pass (a plain-Python version follows the recurrence below).

$$ \begin{aligned} & m_0 \leftarrow -\infty \\ & d_0 \leftarrow 0 \\ & \text{for } j \leftarrow 1, V \text{ do} \\ & \quad m_j \leftarrow \max(m_{j-1}, x_j) \\ & \quad d_j \leftarrow d_{j-1} \times e^{m_{j-1}-m_j} + e^{x_j-m_j} \quad \text{(Update row\_sum\_exp within the loop)} \\ & \text{end for} \\ & \text{for } i \leftarrow 1, V \text{ do} \\ & \quad y_i \leftarrow \frac{e^{x_i-m_V}}{d_V} \\ & \text{end for} \end{aligned} $$

Solution

@triton.jit
def exp_approx(x):
    # exp(x) = 2^(log2(e) * x); Triton recommends exp2 over exp.
    return tl.exp2(1.44269504 * x)


@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """2 loops ver. (online softmax)."""
    block_id_i = tl.program_id(0)
    offset_x = block_id_i * B0 + tl.arange(0, B0)
    mask_x = offset_x < N0

    # Online-softmax state: running row maximum and running sum of exponentials.
    row_max = tl.zeros([B0, 1], dtype=tl.float32)
    row_sum_exp = tl.zeros([B0, 1], dtype=tl.float32)

    # Loop 1: update the running max and rescale the running exp-sum in one pass.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))

        new_row_max = tl.maximum(tl.max(block_value, axis=1, keep_dims=True), row_max)
        # Rescale the old sum by exp(old_max - new_max), then add this block's contribution.
        row_sum_exp = row_sum_exp * exp_approx(row_max - new_row_max) + tl.sum(
            exp_approx(block_value - new_row_max), axis=1, keep_dims=True
        )
        row_max = new_row_max

    # Loop 2: normalize and store.
    for idj in tl.range(0, T, B1):
        offset_y = idj + tl.arange(0, B1)
        mask_y = offset_y < T
        offset_xy = offset_x[:, None] * T + offset_y[None, :]
        mask_xy = mask_x[:, None] & mask_y[None, :]
        block_value = tl.load(x_ptr + offset_xy, mask_xy, other=float('-inf'))
        z_value = exp_approx(block_value - row_max) / row_sum_exp
        tl.store(z_ptr + offset_xy, z_value, mask_xy)

    return
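
For reference, a minimal host-side launch that checks both kernels against torch.softmax. The block sizes (B0=1, B1=32) and the CUDA device are assumptions for illustration; the Triton-Puzzles harness normally drives the kernels for you.

import torch
import triton

def run_softmax(kernel, x, B0=1, B1=32):
    N0, T = x.shape
    z = torch.empty_like(x)
    grid = (triton.cdiv(N0, B0),)
    kernel[grid](x, z, N0, 1, T, B0=B0, B1=B1)  # N1 is unused, so pass a dummy value
    return z

x = torch.randn(4, 200, device="cuda")
for kern in (softmax_kernel_brute_force, softmax_kernel):
    z = run_softmax(kern, x)
    assert torch.allclose(z, torch.softmax(x, dim=1), atol=1e-4)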

Puzzle 9: Simple FlashAttention

A scalar version of FlashAttention.

Uses zero programs. Block size B0 represents the batches of q to process out of N0. Sequence length is T. Process it B1 < T elements (k, v) at a time for some B1.

.. math::

    z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0

This can be done in 1 loop using a similar trick from the last puzzle.

Hint: Use tl.where to mask q dot k to -inf to avoid overflow (NaN).

This is essentially FlashAttention v1, computed in a single online pass.

The full recurrence of FlashAttention v1:

$$
\begin{aligned}
x_i &\leftarrow Q[k,:] \cdot K^T[:,i] \\
m_i &\leftarrow \max(m_{i-1}, x_i) \\
d_i' &\leftarrow d_{i-1}' \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} \\
O_i' &\leftarrow O_{i-1}' \cdot \frac{d_{i-1}'}{d_i'} \cdot e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} \cdot V[i,:]
\end{aligned}
$$

Final output:

$$ O[k,:] \leftarrow O_N' $$

where:

  • $Q[k,:]$ is the $k$-th row of $Q$.
  • $K^T[:,i]$ is the $i$-th column of $K^T$.
  • $x_i$ is the pre-softmax logit.
  • $m_i$ is the running maximum.
  • $d_i'$ is the running sum of exponentials.
  • $O_i'$ is the running partial output.
  • $V[i,:]$ is the $i$-th row of $V$.
  • $O[k,:]$ is the $k$-th row of the output matrix.
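
To make this concrete for the scalar setting of this puzzle (each q_i, k_j, v_j is a scalar), here is the recurrence as a plain-Python sketch; the function name and the test values are illustrative only.

import math
import random

def flash_attention_scalar(q_i, ks, vs):
    # One pass over (k, v): update the running max m, the running exp-sum d,
    # and the running (already normalized) output o, exactly as in the recurrence above.
    m, d, o = float("-inf"), 0.0, 0.0
    for k, v in zip(ks, vs):
        x = q_i * k  # pre-softmax logit
        m_new = max(m, x)
        d_new = d * math.exp(m - m_new) + math.exp(x - m_new)
        o = o * (d / d_new) * math.exp(m - m_new) + math.exp(x - m_new) / d_new * v
        m, d = m_new, d_new
    return o

# Check against the naive softmax-weighted sum.
ks = [random.gauss(0, 1) for _ in range(16)]
vs = [random.gauss(0, 1) for _ in range(16)]
q = 0.7
w = [math.exp(q * k) for k in ks]
ref = sum(wi * vi for wi, vi in zip(w, vs)) / sum(w)
assert abs(flash_attention_scalar(q, ks, vs) - ref) < 1e-9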

Solution

@triton.jit
def myexp(x):
    return tl.exp2(1.44269504 * x)

@triton.jit
def flashatt_kernel(
    q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)

    off_i = block_id_i * B0 + tl.arange(0, B0)
    mask_i = off_i < N0
    # Large finite stand-in for infinity, used to mask invalid scores (see the hint).
    inf = 1.0e6

    q = tl.load(q_ptr + off_i, mask=mask_i)

    # Variable names from Triton's official FlashAttention tutorial are noted in the
    # comments for reference; our names stay consistent with Puzzle 8.

    # l_i: running sum of exponentials
    exp_sum = tl.zeros((B0,), dtype=tl.float32)
    # m_i: running maximum of the scores
    qk_max = tl.full((B0,), -inf, dtype=tl.float32)
    # unnormalized output accumulator
    z = tl.zeros((B0,), dtype=tl.float32)

    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]

        k = tl.load(k_ptr + off_j, mask=mask_j)
        # Scalar attention scores; invalid (i, j) pairs are pushed to -inf so their exp is ~0.
        qk = q[:, None] * k[None, :] + tl.where(mask_ij, 0, -inf)

        # m_ij
        new_max = tl.maximum(tl.max(qk, axis=1), qk_max)
        qk_exp = myexp(qk - new_max[:, None])
        # alpha: rescaling factor for the previously accumulated sums
        factor = myexp(qk_max - new_max)
        # l_ij
        new_exp_sum = exp_sum * factor + tl.sum(qk_exp, axis=1)
        # other=0.0 keeps masked-out lanes from contributing garbage to the sum.
        v = tl.load(v_ptr + off_j, mask=mask_j, other=0.0)
        z = z * factor + tl.sum(qk_exp * v[None, :], axis=1)

        qk_max = new_max
        exp_sum = new_exp_sum

    # Normalize once at the end instead of at every step of the recurrence.
    z = z / exp_sum
    tl.store(z_ptr + off_i, z, mask=mask_i)
    return
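
A minimal host-side check against the naive formulation; the shapes, block sizes, and CUDA device below are assumptions for illustration (the Triton-Puzzles harness handles this for you):

import torch
import triton

N0, T, B0, B1 = 128, 200, 64, 32  # illustrative sizes, with B1 < T
q = torch.randn(N0, device="cuda")
k = torch.randn(T, device="cuda")
v = torch.randn(T, device="cuda")
z = torch.empty(N0, device="cuda")

flashatt_kernel[(triton.cdiv(N0, B0),)](q, k, v, z, N0, T, B0=B0, B1=B1)

# Naive reference: z_i = sum_j softmax(q_i * k)_j * v_j
ref = torch.softmax(q[:, None] * k[None, :], dim=1) @ v
assert torch.allclose(z, ref, atol=1e-4)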

Reference

  1. https://zhuanlan.zhihu.com/p/20269643126
  2. Online softmax

    1. https://arxiv.org/pdf/1805.02867
    2. https://zhuanlan.zhihu.com/p/11656282335
  3. https://zhuanlan.zhihu.com/p/1896616212974261259
  4. Triton debugging tools and methods (CSDN blog)
  5. Flash Attention

    1. https://zhuanlan.zhihu.com/p/669926191
    2. https://zhuanlan.zhihu.com/p/668888063