背景
笔者最近在工作中需要用到一些高性能计算的优化,于是准备着手系统性进行学习。有大佬建议先从triton学起,并且推荐了triton puzzles和triton的tutorial作为入门资料。以下是我练习triton puzzles时对一些解法的分析,记录一下作为心得。
练习题库git
https://github.com/SiriusNEO/Triton-Puzzles-Lite
Puzzle 1:Constant Add(常数加法)
任务:将常数 10 加到一个向量的每个元素上。
问题分析:这是最基础的操作,重点在于理解 Triton 中的指针运算和数据加载。
@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
off_x = tl.arange(0, B0)
x = tl.load(x_ptr + off_x)
z = x + 10.0 # Add the constant
tl.store(z_ptr + off_x, z)
关键点:
- 使用 tl.arange 创建索引范围
- 通过指针偏移加载数据(x_ptr + off_x)
- 完成计算后,将结果存储回目标指针位置
Puzzle 2:Constant Add Block(分块常数加法)
任务:在向量长度大于块大小时,分块添加常数。
问题分析:引入了分块处理的概念,适用于大规模数据计算。
@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
off_x = tl.arange(0, B0) + tl.program_id(0) * B0
mask = off_x < N0
x = tl.load(x_ptr + off_x, mask=mask)
z = x + 10.0
tl.store(z_ptr + off_x, z, mask=mask)
关键点:
- 使用 tl.program_id 获取当前块的 ID
- 构建掩码(mask)确保不越界访问
- 在加载和存储时应用掩码
Puzzle 3:Outer Vector Add(外积加法)
任务:计算两个向量的外积加法。
问题分析:涉及二维矩阵的生成和操作,需要处理两个维度的索引。
@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
i = tl.arange(0, B0)[:, None]
j = tl.arange(0, B1)[None, :]
x = tl.load(x_ptr + i)
y = tl.load(y_ptr + j)
z = x + y
tl.store(z_ptr + i * N1 + j, z)
关键点:
- 使用二维索引范围(i 和 j)
- 两个向量的广播相加
- 目标指针的二维索引转换为一维
Puzzle 4:Outer Vector Add Block(分块外积加法)
任务:在向量长度大于块大小时,分块计算外积加法。
问题分析:结合分块处理和二维矩阵操作,提高计算效率。
@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
block_id_x = tl.program_id(0)
block_id_y = tl.program_id(1)
off_x = block_id_x * B0 + tl.arange(0, B0)
off_y = block_id_y * B1 + tl.arange(0, B1)
mask_x = off_x < N0
mask_y = off_y < N1
x = tl.load(x_ptr + off_x, mask=mask_x)
y = tl.load(y_ptr + off_y, mask=mask_y)
z = x + y
tl.store(z_ptr + off_x * N1 + off_y, z, mask=mask_x & mask_y)
关键点:
- 多块并行处理
- 二维掩码的构建和应用
Puzzle 5:Fused Outer Multiplication(融合外积乘法)
任务:计算两个向量的外积乘法,并应用 ReLU 激活。
问题分析:在 Puzzle 4 的基础上引入非线性变换。
@triton.jit
def mul_relu_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
block_id_x = tl.program_id(0)
block_id_y = tl.program_id(1)
off_x = block_id_x * B0 + tl.arange(0, B0)
off_y = block_id_y * B1 + tl.arange(0, B1)
mask_x = off_x < N0
mask_y = off_y < N1
x = tl.load(x_ptr + off_x, mask=mask_x)
y = tl.load(y_ptr + off_y, mask=mask_y)
z = x * y
z = tl.where(z > 0, z, 0) # Apply ReLU
tl.store(z_ptr + off_x * N1 + off_y, z, mask=mask_x & mask_y)
关键点:
- 元素级乘法操作
- 使用 tl.where 实现条件操作
Puzzle 6:Fused Outer Multiplication - Backwards(反向传播)
任务:计算外积乘法的反向传播。
问题分析:涉及梯度计算和链式法则的应用。
@triton.jit
def mul_relu_block_back_kernel(x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
block_id_x = tl.program_id(0)
block_id_y = tl.program_id(1)
off_x = block_id_x * B0 + tl.arange(0, B0)
off_y = block_id_y * B1 + tl.arange(0, B1)
mask_x = off_x < N0
mask_y = off_y < N1
x = tl.load(x_ptr + off_x, mask=mask_x)
y = tl.load(y_ptr + off_y, mask=mask_y)
dz = tl.load(dz_ptr + off_x * N1 + off_y, mask=mask_x & mask_y)
dx = dz * y
tl.store(dx_ptr + off_x * N1 + off_y, dx, mask=mask_x & mask_y)
关键点:
- 梯度的计算和传播
- 结合前向计算的中间结果
Puzzle 7:Long Sum(长向量求和)
任务:计算批量数据的求和。
问题分析:涉及循环操作和块间同步,是高性能计算中的常见操作。
@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
pid = tl.program_id(0)
block_start = pid * B0
offsets = block_start + tl.arange(0, B0)
mask = offsets < N0
x = tl.load(x_ptr + offsets * T + tl.arange(0, T), mask=mask)
z = tl.sum(x, 1)
tl.store(z_ptr + offsets, z, mask=mask)
关键点:
- 循环求和操作
- 块间的数据独立性
一些想法
- mask的操作值得学习
- 争取尽快做完后面的题目
本文由博客一文多发平台 OpenWrite 发布!
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。