背景

笔者最近在工作中需要用到一些高性能计算的优化,于是准备着手系统性进行学习。有大佬建议先从triton学起,并且推荐了triton puzzles和triton的tutorial作为入门资料。以下是我练习triton puzzles时对一些解法的分析,记录一下作为心得。

练习题库git

https://github.com/SiriusNEO/Triton-Puzzles-Lite

Puzzle 1:Constant Add(常数加法)

任务:将常数 10 加到一个向量的每个元素上。

问题分析:这是最基础的操作,重点在于理解 Triton 中的指针运算和数据加载。

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    off_x = tl.arange(0, B0)
    x = tl.load(x_ptr + off_x)
    z = x + 10.0  # Add the constant
    tl.store(z_ptr + off_x, z)

关键点:

  1. 使用 tl.arange 创建索引范围
  2. 通过指针偏移加载数据(x_ptr + off_x)
  3. 完成计算后,将结果存储回目标指针位置

Puzzle 2:Constant Add Block(分块常数加法)

任务:在向量长度大于块大小时,分块添加常数。

问题分析:引入了分块处理的概念,适用于大规模数据计算。

@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    off_x = tl.arange(0, B0) + tl.program_id(0) * B0
    mask = off_x < N0
    x = tl.load(x_ptr + off_x, mask=mask)
    z = x + 10.0
    tl.store(z_ptr + off_x, z, mask=mask)

关键点:

  1. 使用 tl.program_id 获取当前块的 ID
  2. 构建掩码(mask)确保不越界访问
  3. 在加载和存储时应用掩码

Puzzle 3:Outer Vector Add(外积加法)

任务:计算两个向量的外积加法。

问题分析:涉及二维矩阵的生成和操作,需要处理两个维度的索引。

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    i = tl.arange(0, B0)[:, None]
    j = tl.arange(0, B1)[None, :]
    x = tl.load(x_ptr + i)
    y = tl.load(y_ptr + j)
    z = x + y
    tl.store(z_ptr + i * N1 + j, z)

关键点:

  1. 使用二维索引范围(i 和 j)
  2. 两个向量的广播相加
  3. 目标指针的二维索引转换为一维

Puzzle 4:Outer Vector Add Block(分块外积加法)

任务:在向量长度大于块大小时,分块计算外积加法。

问题分析:结合分块处理和二维矩阵操作,提高计算效率。

@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)
    mask_x = off_x < N0
    mask_y = off_y < N1
    x = tl.load(x_ptr + off_x, mask=mask_x)
    y = tl.load(y_ptr + off_y, mask=mask_y)
    z = x + y
    tl.store(z_ptr + off_x * N1 + off_y, z, mask=mask_x & mask_y)

关键点:

  1. 多块并行处理
  2. 二维掩码的构建和应用

Puzzle 5:Fused Outer Multiplication(融合外积乘法)

任务:计算两个向量的外积乘法,并应用 ReLU 激活。

问题分析:在 Puzzle 4 的基础上引入非线性变换。

@triton.jit
def mul_relu_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)
    mask_x = off_x < N0
    mask_y = off_y < N1
    x = tl.load(x_ptr + off_x, mask=mask_x)
    y = tl.load(y_ptr + off_y, mask=mask_y)
    z = x * y
    z = tl.where(z > 0, z, 0)  # Apply ReLU
    tl.store(z_ptr + off_x * N1 + off_y, z, mask=mask_x & mask_y)

关键点:

  1. 元素级乘法操作
  2. 使用 tl.where 实现条件操作

Puzzle 6:Fused Outer Multiplication - Backwards(反向传播)

任务:计算外积乘法的反向传播。

问题分析:涉及梯度计算和链式法则的应用。

@triton.jit
def mul_relu_block_back_kernel(x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)
    mask_x = off_x < N0
    mask_y = off_y < N1
    x = tl.load(x_ptr + off_x, mask=mask_x)
    y = tl.load(y_ptr + off_y, mask=mask_y)
    dz = tl.load(dz_ptr + off_x * N1 + off_y, mask=mask_x & mask_y)
    dx = dz * y
    tl.store(dx_ptr + off_x * N1 + off_y, dx, mask=mask_x & mask_y)

关键点:

  1. 梯度的计算和传播
  2. 结合前向计算的中间结果

Puzzle 7:Long Sum(长向量求和)

任务:计算批量数据的求和。

问题分析:涉及循环操作和块间同步,是高性能计算中的常见操作。

@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid = tl.program_id(0)
    block_start = pid * B0
    offsets = block_start + tl.arange(0, B0)
    mask = offsets < N0
    x = tl.load(x_ptr + offsets * T + tl.arange(0, T), mask=mask)
    z = tl.sum(x, 1)
    tl.store(z_ptr + offsets, z, mask=mask)

关键点:

  1. 循环求和操作
  2. 块间的数据独立性

一些想法

  1. mask的操作值得学习
  2. 争取尽快做完后面的题目
本文由博客一文多发平台 OpenWrite 发布!

jay_kay
0 声望0 粉丝

永远不要忘记,我依然爱着你