Puzzle 10: Two Dimensional Convolution
A batched 2D convolution.

Uses one program id axis. Block size B0 represents the batches to process out of N0.
Image x is size H by W with only 1 channel, and kernel k is size KH by KW.

.. math::
    z_{i, j, l} = \sum_{oj, ol}^{j+oj\le H, l+ol\le W} k_{oj,ol} \times x_{i, j+oj, l+ol}
    \text{ for } i = 1\ldots N_0 \text{ for } j = 1\ldots H \text{ for } l = 1\ldots W
"""
def conv2d_spec(x: Float32[4, 8, 8], k: Float32[4, 4]) -> Float32[4, 8, 8]:
    z = torch.zeros(4, 8, 8)
    x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
    # print(x.shape, k.shape)
    for i in range(8):
        for j in range(8):
            z[:, i, j] = (k[None, :, :] * x[:, i : i + 4, j : j + 4]).sum(1).sum(1)
    return z
@triton.jit
def conv2d_kernel(
    x_ptr, k_ptr, z_ptr, N0, H, W, KH: tl.constexpr, KW: tl.constexpr, B0: tl.constexpr
):
    """2D convolution kernel implemented with @triton.jit.

    Arguments:
        x_ptr: pointer to the input tensor
        k_ptr: pointer to the convolution kernel
        z_ptr: pointer to the output tensor
        N0: batch size
        H: input height
        W: input width
        KH: kernel height (compile-time constant)
        KW: kernel width (compile-time constant)
        B0: batch block size (compile-time constant)

    Performs a 2D convolution of the input against the kernel and writes the result
    to the output, processing B0 batches per program.
    """
    pid_0 = tl.program_id(0)
    off_i = pid_0 * B0 + tl.arange(0, B0)
    mask_i = off_i < N0
    off_h = tl.arange(0, KH)
    off_w = tl.arange(0, KW)
    # Offsets of the KH x KW kernel, laid out row-major.
    off_hw = off_h[:, None] * KW + off_w[None, :]
    conv_kernel = tl.load(k_ptr + off_hw)
    for j in tl.range(0, H):
        for l in tl.range(0, W):
            # Input window of shape (B0, KH, KW) anchored at (j, l); taps outside the image read 0.
            off_j_oj = j + off_h[None, :, None]
            off_l_ol = l + off_w[None, None, :]
            off_x = off_i[:, None, None] * H * W + off_j_oj * W + off_l_ol
            mask_x = mask_i[:, None, None] & (off_j_oj < H) & (off_l_ol < W)
            x = tl.load(x_ptr + off_x, mask=mask_x, other=0.0)
            # Reduce over the kernel axes only, keeping the batch axis.
            z = tl.sum(tl.sum(x * conv_kernel[None, :, :], axis=2), axis=1)
            off_z = off_i * H * W + j * W + l
            tl.store(z_ptr + off_z, z, mask=mask_i)
    return
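A minimal launch sketch (not part of the puzzle; the helper name is hypothetical): it assumes a CUDA device and the sizes used by conv2d_spec, launches one program per block of B0 batches, and checks the kernel against the eager spec above.

def check_conv2d_against_spec():
    # Assumed sizes: they follow conv2d_spec (4 batches of 8x8 images, a 4x4 kernel).
    N0, H, W, KH, KW, B0 = 4, 8, 8, 4, 4, 1
    x = torch.randn(N0, H, W, device='cuda')
    k = torch.randn(KH, KW, device='cuda')
    z = torch.zeros(N0, H, W, device='cuda')
    # One program id axis: one program per block of B0 batches.
    grid = (triton.cdiv(N0, B0),)
    conv2d_kernel[grid](x, k, z, N0, H, W, KH=KH, KW=KW, B0=B0)
    # conv2d_spec runs on CPU tensors, so compare on the CPU side.
    torch.testing.assert_close(z.cpu(), conv2d_spec(x.cpu(), k.cpu()), rtol=1e-4, atol=1e-4)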
Unit tests

Test 1: basic convolution with a small input tensor and kernel.

import unittest

import torch
import triton

from puzzles import conv2d_kernel


class PuzzlesTest(unittest.TestCase):
    def test_basic_functionality_with_small_inputs(self):
        """Test basic convolution operation with small input tensor and kernel."""
        # Input parameters
        N0, H, W = 1, 3, 3
        KH, KW = 2, 2
        B0 = 1
        # Create input tensor and kernel with simple values
        x = torch.tensor(
            [[1.0, 2.0, 3.0],
             [4.0, 5.0, 6.0],
             [7.0, 8.0, 9.0]],
            device='cuda',
        ).reshape(1, H, W)
        k = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device='cuda')
        # Expected output: the kernel writes a full HxW map and treats taps outside
        # the image as zero, so z[i, j] = x[i, j] + x[i + 1, j + 1].
        expected_z = torch.tensor(
            [[1.0 + 5.0, 2.0 + 6.0, 3.0],
             [4.0 + 8.0, 5.0 + 9.0, 6.0],
             [7.0, 8.0, 9.0]],
            device='cuda',
        ).reshape(1, H, W)
        # Allocate output tensor (same spatial size as the input)
        z = torch.zeros((N0, H, W), device='cuda')
        # Define grid: one program per block of B0 batches
        grid = lambda meta: (triton.cdiv(N0, meta['B0']),)
        # Run the kernel (tensors are passed directly; Triton takes their pointers)
        conv2d_kernel[grid](x, k, z, N0, H, W, KH=KH, KW=KW, B0=B0)
        # Check if results match
        self.assertTrue(
            torch.allclose(z, expected_z, rtol=1e-3, atol=1e-3),
            f"Expected:\n{expected_z}\nGot:\n{z}",
        )


if __name__ == '__main__':
    unittest.main()
Test 2: N0 exactly divisible by B0 (one full block).

import unittest

import torch
import triton

from puzzles import conv2d_kernel


class PuzzlesTest(unittest.TestCase):
    def test_conv2d_kernel_full_block_processing(self):
        """Test conv2d_kernel when N0 is exactly divisible by B0."""
        # Input parameters (KH and KW must be powers of 2 because tl.arange requires it)
        N0, H, W = 32, 5, 5
        KH, KW = 4, 4
        B0 = 32
        # Create random input tensors
        x = torch.randn(N0, H, W, device='cuda')
        k = torch.randn(KH, KW, device='cuda')
        z = torch.zeros(N0, H, W, device='cuda')
        # Define grid function
        grid = lambda meta: (triton.cdiv(N0, meta['B0']),)
        # Launch the kernel (tensors are passed directly)
        conv2d_kernel[grid](x, k, z, N0, H, W, KH=KH, KW=KW, B0=B0)
        # Compute expected output with F.conv2d. The kernel zero-pads only to the
        # bottom/right, so pad there and run a "valid" cross-correlation.
        x_4d = torch.nn.functional.pad(x.unsqueeze(1), (0, KW - 1, 0, KH - 1))
        k_4d = k.view(1, 1, KH, KW)
        expected_z = torch.nn.functional.conv2d(x_4d, k_4d).squeeze(1)
        # Check if results match
        self.assertTrue(
            torch.allclose(z, expected_z, rtol=1e-3, atol=1e-3),
            'Output does not match expected result',
        )


if __name__ == '__main__':
    unittest.main()
Test 3: random input values against a PyTorch reference.

import unittest

import torch
import triton

from puzzles import conv2d_kernel


class PuzzlesTest(unittest.TestCase):
    def test_conv2d_kernel_random_values(self):
        """Test conv2d_kernel with random input values."""
        # Setup test parameters
        N0, H, W = 4, 10, 10
        KH, KW = 4, 4
        B0 = 2
        # Generate random inputs
        torch.manual_seed(0)
        x = torch.rand((N0, H, W), device='cuda') - 0.5
        k = torch.rand((KH, KW), device='cuda') - 0.5
        # Allocate output tensor
        z = torch.zeros((N0, H, W), device='cuda')
        # Compute reference result: zero-pad bottom/right (matching the kernel's
        # boundary handling), then run a "valid" cross-correlation.
        x_4d = torch.nn.functional.pad(x.unsqueeze(1), (0, KW - 1, 0, KH - 1))
        k_4d = k.unsqueeze(0).unsqueeze(0)  # add out_channels and in_channels dims
        z_ref = torch.nn.functional.conv2d(x_4d, k_4d).squeeze(1)
        # Run the kernel (tensors are passed directly)
        grid = lambda meta: (triton.cdiv(N0, meta['B0']),)
        conv2d_kernel[grid](x, k, z, N0, H, W, KH=KH, KW=KW, B0=B0)
        # Verify results
        self.assertTrue(
            torch.allclose(z, z_ref, rtol=1e-3, atol=1e-3),
            'Convolution results do not match reference',
        )


if __name__ == '__main__':
    unittest.main()
r"""
Puzzle 11: Matrix Multiplication

A blocked matrix multiplication.

Uses three program id axes. Block size B2 represents the batches to process out of N2.
Block size B0 represents the rows of x to process out of N0. Block size B1 represents
the cols of y to process out of N1. The middle shape is MID.

.. math::
    z_{i, j, k} = \sum_{l} x_{i, j, l} \times y_{i, l, k}
    \text{ for } i = 1\ldots N_2, j = 1\ldots N_0, k = 1\ldots N_1

You are allowed to use tl.dot which computes a smaller mat mul.

Hint: the main trick is that you can split a matmul into smaller parts.

.. math::
    z_{i, j, k} = \sum_{l=1}^{L/2} x_{i, j, l} \times y_{i, l, k} + \sum_{l=L/2}^{L} x_{i, j, l} \times y_{i, l, k}
"""
def dot_spec(x: Float32[4, 32, 32], y: Float32[4, 32, 32]) -> Float32[4, 32, 32]:
    return x @ y
@triton.jit
def dot_kernel(
    x_ptr,
    y_ptr,
    z_ptr,
    N0,
    N1,
    N2,
    MID,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B2: tl.constexpr,
    B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)
    block_id_k = tl.program_id(1)
    block_id_i = tl.program_id(2)
    off_i = block_id_i * B2 + tl.arange(0, B2)
    off_j = block_id_j * B0 + tl.arange(0, B0)
    off_k = block_id_k * B1 + tl.arange(0, B1)
    mask_i = off_i < N2
    mask_j = off_j < N0
    mask_k = off_k < N1
    # Accumulator for the (B2, B0, B1) output tile.
    z = tl.zeros((B2, B0, B1), dtype=tl.float32)
    off_z = off_i[:, None, None] * N0 * N1 + off_j[None, :, None] * N1 + off_k[None, None, :]
    mask_z = mask_i[:, None, None] & mask_j[None, :, None] & mask_k[None, None, :]
    # Walk along the shared MID dimension in chunks of B_MID.
    for l in range(0, MID, B_MID):
        off_b = l + tl.arange(0, B_MID)
        mask_b = off_b < MID
        off_x = off_i[:, None, None] * N0 * MID + off_j[None, :, None] * MID + off_b[None, None, :]
        off_y = off_i[:, None, None] * MID * N1 + off_b[None, :, None] * N1 + off_k[None, None, :]
        mask_x = mask_i[:, None, None] & mask_j[None, :, None] & mask_b[None, None, :]
        mask_y = mask_i[:, None, None] & mask_b[None, :, None] & mask_k[None, None, :]
        x = tl.load(x_ptr + off_x, mask=mask_x, other=0.0)
        y = tl.load(y_ptr + off_y, mask=mask_y, other=0.0)
        # Accumulate the partial product; overwriting z here would drop earlier chunks.
        z += tl.dot(x, y, allow_tf32=False)
    tl.store(z_ptr + off_z, z, mask=mask_z)
    return
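A rough usage sketch (assumptions: CUDA device, the dot_spec shapes, and block sizes of at least 16 per tile dimension so tl.dot accepts them; the helper name is made up):

def check_dot_against_spec():
    N2, N0, N1, MID = 4, 32, 32, 32
    B2, B0, B1, B_MID = 1, 16, 16, 16
    x = torch.randn(N2, N0, MID, device='cuda')
    y = torch.randn(N2, MID, N1, device='cuda')
    z = torch.zeros(N2, N0, N1, device='cuda')
    # Axis 0 tiles rows of x, axis 1 tiles cols of y, axis 2 tiles batches.
    grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1), triton.cdiv(N2, B2))
    dot_kernel[grid](x, y, z, N0, N1, N2, MID, B0=B0, B1=B1, B2=B2, B_MID=B_MID)
    torch.testing.assert_close(z.cpu(), dot_spec(x.cpu(), y.cpu()), rtol=1e-4, atol=1e-4)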
r"""
Puzzle 12: Quantized Matrix Mult

When doing matrix multiplication with quantized neural networks a common strategy is to store the weight matrix in lower precision, with a shift and scale term.

For this problem our weight will be stored in 4 bits. We can store FPINT of these in a 32 bit integer. In addition, for every group of weights in order we will store 1 scale float value and 1 shift 4-bit value. We store these per column of weight. The activations are stored separately in standard floats.

Mathematically it looks like:

.. math::
    z_{j, k} = \sum_{l} sc_{j, \frac{l}{g}} (w_{j, l} - sh_{j, \frac{l}{g}}) \times y_{l, k}
    \text{ for } j = 1\ldots N_0, k = 1\ldots N_1

Where g is the number of groups (GROUP).

However, it is a bit more complex since we need to also extract the 4-bit values into floats to begin.

Note:
- We don't consider batch size, i.e. i, in this puzzle.
- Remember to unpack the FPINT values into separate 4-bit values. This contains some shape manipulation.
"""
FPINT = 32 // 4
GROUP = 8
def quant_dot_spec(
    scale: Float32[32, 8],
    offset: Int32[32,],
    weight: Int32[32, 8],
    activation: Float32[64, 32],
) -> Float32[32, 32]:
    offset = offset.view(32, 1)

    def extract(x):
        over = torch.arange(8) * 4
        mask = 2**4 - 1
        return (x[..., None] >> over) & mask

    scale = scale[..., None].expand(-1, 8, GROUP).contiguous().view(-1, 64)
    offset = (
        extract(offset)[..., None].expand(-1, 1, 8, GROUP).contiguous().view(-1, 64)
    )
    return (scale * (extract(weight).view(-1, 64) - offset)) @ activation
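To make the packing convention concrete, here is a tiny self-contained illustration (not part of the puzzle): packed value p of a row lives at bits 4p..4p+3 of the int32, which is exactly what extract() above undoes.

def pack_unpack_demo():
    vals = torch.arange(8, dtype=torch.int32)             # eight 4-bit values: 0..7
    packed = torch.zeros((), dtype=torch.int32)
    for p in range(8):
        packed |= vals[p] << (4 * p)                       # place value p at bits [4p, 4p+4)
    over = torch.arange(8) * 4
    unpacked = (packed[..., None] >> over) & (2**4 - 1)    # same recipe as extract()
    assert bool((unpacked == vals).all())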
@triton.jit
def quant_dot_kernel(
    scale_ptr,
    offset_ptr,
    weight_ptr,
    activation_ptr,
    z_ptr,
    N0,
    N1,
    MID,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)
    block_id_k = tl.program_id(1)
    off_j = block_id_j * B0 + tl.arange(0, B0)
    off_k = block_id_k * B1 + tl.arange(0, B1)
    mask_j = off_j < N0
    mask_k = off_k < N1
    z = tl.zeros((B0, B1), dtype=tl.float32)
    off_z = off_j[:, None] * N1 + off_k[None, :]
    mask_z = mask_j[:, None] & mask_k[None, :]
    for l in tl.range(0, MID, B_MID):
        # Load scale: one float per GROUP of unpacked weight columns.
        off_l_div_g = tl.arange(0, B_MID // GROUP) + (l // GROUP)
        mask_l_div_g = off_l_div_g < (MID // GROUP)
        off_scale = off_j[:, None] * (MID // GROUP) + off_l_div_g[None, :]
        mask_scale = mask_j[:, None] & mask_l_div_g[None, :]
        scale = tl.load(scale_ptr + off_scale, mask=mask_scale)
        # Load shift (offset): each 32-bit integer stores FPINT (8) 4-bit shifts per row.
        shift = tl.load(offset_ptr + off_j, mask=mask_j)
        # Load weight: 4-bit values packed FPINT per 32-bit integer, so index packed
        # columns starting at l // FPINT.
        off_weight_l = l // FPINT + tl.arange(0, B_MID // FPINT)
        mask_weight_l = off_weight_l < (MID // FPINT)
        off_weight = off_j[:, None] * (MID // FPINT) + off_weight_l[None, :]
        mask_weight = mask_j[:, None] & mask_weight_l[None, :]
        weight = tl.load(weight_ptr + off_weight, mask=mask_weight)
        # Load activation as normal floats.
        off_l = l + tl.arange(0, B_MID)
        mask_l = off_l < MID
        off_activation = off_l[:, None] * N1 + off_k[None, :]
        mask_activation = mask_l[:, None] & mask_k[None, :]
        activation = tl.load(activation_ptr + off_activation, mask=mask_activation)
        # Unpack weight and shift: nibble p sits at bits [4p, 4p + 4).
        BITS = 32 // FPINT
        unpack_offs = tl.arange(0, FPINT) * BITS
        unpack_upperbound_mask = (1 << BITS) - 1
        unpacked_shift = (shift[:, None] >> unpack_offs) & unpack_upperbound_mask
        unpacked_weight = (weight[:, :, None] >> unpack_offs) & unpack_upperbound_mask
        # Dequantize: [B0, 8, 1] * ([B0, 8, 8] - [B0, 8, 1]).
        transformed_weight = scale[:, :, None] * (
            unpacked_weight - unpacked_shift[:, :, None]
        )
        # Flatten back to the unpacked column layout, shape [B0, 64].
        transformed_weight = transformed_weight.reshape(
            unpacked_shift.shape[0], unpacked_shift.shape[-1] * FPINT
        )
        # Accumulate the partial matmul for this chunk of MID.
        z += tl.dot(transformed_weight, activation)
    tl.store(z_ptr + off_z, z, mask=mask_z)
    return
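A minimal launch sketch for the quantized kernel (assumptions: CUDA device, the quant_dot_spec shapes, and B_MID = MID so one block covers every packed scale/shift group; the helper name and the loose tolerance, which allows for possible TF32 use in tl.dot, are my own choices):

def check_quant_dot_against_spec():
    N0, N1, MID = 32, 32, 64
    B0, B1, B_MID = 16, 16, 64
    scale = torch.randn(N0, MID // GROUP, device='cuda')
    offset = torch.randint(0, 2**31 - 1, (N0,), dtype=torch.int32, device='cuda')
    weight = torch.randint(0, 2**31 - 1, (N0, MID // FPINT), dtype=torch.int32, device='cuda')
    activation = torch.randn(MID, N1, device='cuda')
    z = torch.zeros(N0, N1, device='cuda')
    grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1))
    quant_dot_kernel[grid](scale, offset, weight, activation, z,
                           N0, N1, MID, B0=B0, B1=B1, B_MID=B_MID)
    ref = quant_dot_spec(scale.cpu(), offset.cpu(), weight.cpu(), activation.cpu())
    torch.testing.assert_close(z.cpu(), ref, rtol=1e-2, atol=1e-2)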