Puzzle 10: Two Dimensional Convolution

A batched 2D convolution.

Uses one program id axis. Block size B0 represents the batches to process out of N0.
Image x is of size H by W with only 1 channel, and kernel k is of size KH by KW.

.. math::
z_{i, j, l} = \sum_{oj, ol}^{j+oj\le H, l+ol\le W} k_{oj,ol} \times x_{i,j + oj, l + ol}
\text{ for } i = 1\ldots N_0 \text{ for } j = 1\ldots H \text{ for } l = 1\ldots W
"""

def conv2d_spec(x: Float32[4, 8, 8], k: Float32[4, 4]) -> Float32[4, 8, 8]:
    z = torch.zeros(4, 8, 8)
    x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
    # print(x.shape, k.shape)
    for i in range(8):
        for j in range(8):
            z[:, i, j] = (k[None, :, :] * x[:, i : i + 4, j : j + 4]).sum(1).sum(1)
    return z

@triton.jit
def conv2d_kernel(
    x_ptr, k_ptr, z_ptr, N0, H, W, KH: tl.constexpr, KW: tl.constexpr, B0: tl.constexpr
):
    """2D convolution kernel implemented with @triton.jit.

    Args:
        x_ptr: pointer to the input tensor
        k_ptr: pointer to the convolution kernel
        z_ptr: pointer to the output tensor
        N0: batch size
        H: input height
        W: input width
        KH: kernel height (compile-time constant)
        KW: kernel width (compile-time constant)
        B0: batch block size (compile-time constant)

    Performs a 2D convolution of the input with the kernel and writes the result
    to the output tensor, processing B0 batches per program instance.
    """
    pid_0 = tl.program_id(0)
    off_i = pid_0 * B0 + tl.arange(0, B0)
    mask_i = off_i < N0

    # Load the full KH x KW convolution kernel once.
    off_h = tl.arange(0, KH)
    off_w = tl.arange(0, KW)
    off_hw = off_h[:, None] * KW + off_w[None, :]
    conv_kernel = tl.load(k_ptr + off_hw)

    for j in tl.range(0, H):
        for l in tl.range(0, W):
            # Gather the KH x KW window anchored at (j, l) for every batch in the block.
            off_j_oj = j + off_h[None, :, None]
            off_l_ol = l + off_w[None, None, :]
            off_x = off_i[:, None, None] * H * W + off_j_oj * W + off_l_ol
            mask_x = mask_i[:, None, None] & (off_j_oj < H) & (off_l_ol < W)
            x = tl.load(x_ptr + off_x, mask=mask_x, other=0.0)

            # Reduce over the kernel window, keeping the batch axis.
            z = tl.sum(tl.sum(x * conv_kernel[None, :, :], axis=2), axis=1)
            off_z = off_i * H * W + j * W + l
            tl.store(z_ptr + off_z, z, mask=mask_i)
    return
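
A minimal host-side launch looks like the sketch below. It assumes a CUDA device and the default puzzle shapes (N0 = 4, H = W = 8, KH = KW = 4) with B0 = 1; the block size is an arbitrary choice for illustration, and the result is simply checked against conv2d_spec.

import torch
import triton

N0, H, W, KH, KW, B0 = 4, 8, 8, 4, 4, 1
x = torch.randn(N0, H, W, device="cuda")
k = torch.randn(KH, KW, device="cuda")
z = torch.zeros(N0, H, W, device="cuda")

# One program instance per B0 batches.
grid = lambda meta: (triton.cdiv(N0, meta["B0"]),)
conv2d_kernel[grid](x, k, z, N0, H, W, KH=KH, KW=KW, B0=B0)

# conv2d_spec builds its output on the CPU, so compare there.
print(torch.allclose(z.cpu(), conv2d_spec(x.cpu(), k.cpu()), atol=1e-4))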

unittest

  • Test basic convolution operation with small input tensor and kernel

    import unittest
    import torch
    import triton
    from puzzles import conv2d_kernel
    
    class PuzzlesTest(unittest.TestCase):
        def test_basic_functionality_with_small_inputs(self):
            """Test basic convolution operation with small input tensor and kernel"""
            # Input parameters
            N0 = 1
            H = 3
            W = 3
            KH = 2
            KW = 2
            B0 = 1
    
            # Create input tensor and kernel with simple values
            x = torch.tensor([
                [1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0],
                [7.0, 8.0, 9.0]
            ], device='cuda').reshape(1, H, W)
    
            k = torch.tensor([
                [1.0, 0.0],
                [0.0, 1.0]
            ], device='cuda')
    
            # Expected output: the kernel produces an H x W map with implicit zero
            # padding below and to the right of the image, so for this kernel
            # z[j, l] = x[j, l] + x[j+1, l+1] (out-of-bounds terms are zero).
            expected_z = torch.tensor([
                [6.0, 8.0, 3.0],
                [12.0, 14.0, 6.0],
                [7.0, 8.0, 9.0]
            ], device='cuda').reshape(1, H, W)

            # Allocate output tensor (same spatial size as the input)
            z = torch.zeros((N0, H, W), device='cuda')

            # Define grid
            grid = lambda meta: (triton.cdiv(N0, meta['B0']),)

            # Run the kernel (tensors are passed directly; Triton takes their pointers)
            conv2d_kernel[grid](x, k, z, N0, H, W, KH, KW, B0=B0)
    
            # Check if results match
            self.assertTrue(torch.allclose(z, expected_z, rtol=1e-3, atol=1e-3),
                           f"Expected:\n{expected_z}\nGot:\n{z}")
    
    if __name__ == '__main__':
        unittest.main()
  • Test when N0 is exactly divisible by B0

    import unittest
    import torch
    import triton
    from puzzles import conv2d_kernel
    
    class PuzzlesTest(unittest.TestCase):
        def test_conv2d_kernel_full_block_processing(self):
            """Test conv2d_kernel when N0 is exactly divisible by B0"""
            # Input parameters
            N0 = 32
            H = 5
            W = 5
            KH = 3
            KW = 3
            B0 = 32
            
            # Create random input tensors
            x = torch.randn(N0, H, W, device='cuda')
            k = torch.randn(KH, KW, device='cuda')
            z = torch.empty(N0, H, W, device='cuda')
            
            # Define grid function
            grid = lambda meta: (triton.cdiv(N0, meta['B0']),)

            # Launch the kernel (tensors are passed directly; Triton takes their pointers)
            conv2d_kernel[grid](x, k, z, N0, H, W, KH, KW, B0=B0)

            # Compute expected output with PyTorch's conv2d.
            # The kernel pads only below and to the right of the image, so pad
            # (KH - 1, KW - 1) on those sides and run a "valid" convolution.
            x_4d = torch.nn.functional.pad(x.view(N0, 1, H, W), (0, KW - 1, 0, KH - 1))
            k_4d = k.view(1, 1, KH, KW)
            expected_z = torch.nn.functional.conv2d(x_4d, k_4d).squeeze(1)
            
            # Check if results match
            self.assertTrue(torch.allclose(z, expected_z, rtol=1e-3, atol=1e-3),
                            "Output does not match expected result")
    
    if __name__ == '__main__':
        unittest.main()
  • Test with random input values

    import unittest
    import torch
    import triton
    from puzzles import conv2d_kernel
    from tensor_type import Float32
    
    class PuzzlesTest(unittest.TestCase):
        def test_conv2d_kernel_random_values(self):
            """Test conv2d_kernel with random input values"""
            # Setup test parameters
            N0 = 4
            H = 10
            W = 10
            KH = 4
            KW = 4
            B0 = 2
            
            # Generate random inputs
            torch.manual_seed(0)
            x = torch.rand((N0, H, W), device='cuda') - 0.5
            k = torch.rand((KH, KW), device='cuda') - 0.5
            
            # Allocate output tensor
            z = torch.empty((N0, H, W), device='cuda')
            
            # Compute reference result with PyTorch's conv2d.
            # The kernel pads only below and to the right of the image, so pad
            # (KH - 1, KW - 1) on those sides and run a "valid" convolution.
            x_4d = torch.nn.functional.pad(x.unsqueeze(1), (0, KW - 1, 0, KH - 1))
            k_4d = k.unsqueeze(0).unsqueeze(0)  # Add out_channels and in_channels dimensions
            z_ref = torch.nn.functional.conv2d(x_4d, k_4d).squeeze(1)

            # Run the kernel (tensors are passed directly; Triton takes their pointers)
            grid = lambda meta: (triton.cdiv(N0, meta['B0']),)
            conv2d_kernel[grid](x, k, z, N0, H, W, KH, KW, B0=B0)
            
            # Verify results
            self.assertTrue(torch.allclose(z, z_ref, rtol=1e-3, atol=1e-3),
                            "Convolution results do not match reference")
    
    if __name__ == '__main__':
        unittest.main()

r"""

Puzzle 11: Matrix Multiplication

A blocked matrix multiplication.

Uses three program id axes. Block size B2 represents the batches to process out of N2.
Block size B0 represents the rows of x to process out of N0. Block size B1 represents the columns
of y to process out of N1. The middle (shared) dimension is MID.

.. math::
z_{i, j, k} = \sum_{l} x_{i,j, l} \times y_{i, l, k} \text{ for } i = 1\ldots N_2, j = 1\ldots N_0, k = 1\ldots N_1

You are allowed to use tl.dot which computes a smaller mat mul.

Hint: the main trick is that you can split a matmul into smaller parts.

.. math::
z_{i, j, k} = \sum_{l=1}^{L/2} x_{i,j, l} \times y_{i, l, k} + \sum_{l=L/2+1}^{L} x_{i,j, l} \times y_{i, l, k}
"""

def dot_spec(x: Float32[4, 32, 32], y: Float32[4, 32, 32]) -> Float32[4, 32, 32]:
    return x @ y

@triton.jit
def dot_kernel(
    x_ptr,
    y_ptr,
    z_ptr,
    N0,
    N1,
    N2,
    MID,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B2: tl.constexpr,
    B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)
    block_id_k = tl.program_id(1)
    block_id_i = tl.program_id(2)
    off_i = block_id_i * B2 + tl.arange(0, B2)
    off_j = block_id_j * B0 + tl.arange(0, B0)
    off_k = block_id_k * B1 + tl.arange(0, B1)

    mask_i = off_i < N2
    mask_j = off_j < N0
    mask_k = off_k < N1

    z = tl.zeros((B2, B0, B1), dtype=tl.float32)
    off_z = off_i[:,None,None] * N0 * N1 + off_j[None,:,None] * N1 + off_k[None,None,:]
    mask_z = mask_i[:,None,None] & mask_j[None,:,None] & mask_k[None,None,:]

    # Accumulate partial products over the middle dimension in chunks of B_MID.
    for l in range(0, MID, B_MID):
        off_b = l + tl.arange(0, B_MID)
        mask_b = off_b < MID
        off_x = off_i[:,None,None] * N0 * MID + off_j[None,:,None] * MID + off_b[None,None,:]
        off_y = off_i[:,None,None] * MID * N1 + off_b[None,:,None] * N1 + off_k[None,None,:]
        mask_x = mask_i[:,None,None] & mask_j[None,:,None] & mask_b[None,None,:]
        mask_y = mask_i[:,None,None] & mask_b[None,:,None] & mask_k[None,None,:]

        x = tl.load(x_ptr + off_x, mask=mask_x, other=0.0)
        y = tl.load(y_ptr + off_y, mask=mask_y, other=0.0)
        z += tl.dot(x, y, allow_tf32=False)

    # Write the accumulated block once, after the reduction loop.
    tl.store(z_ptr + off_z, z, mask=mask_z)
    return
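
A minimal launch sketch for dot_kernel is below. It assumes a CUDA device, the spec's 4 x 32 x 32 shapes, a Triton version whose tl.dot accepts batched 3D inputs, and block sizes chosen so tl.dot sees at least 16 elements per dimension; all of these are illustrative choices.

import torch
import triton

N2, N0, N1, MID = 4, 32, 32, 32
B2, B0, B1, B_MID = 1, 16, 16, 16

x = torch.randn(N2, N0, MID, device="cuda")
y = torch.randn(N2, MID, N1, device="cuda")
z = torch.zeros(N2, N0, N1, device="cuda")

# Axis 0 tiles the rows of x, axis 1 the columns of y, axis 2 the batches.
grid = lambda meta: (
    triton.cdiv(N0, meta["B0"]),
    triton.cdiv(N1, meta["B1"]),
    triton.cdiv(N2, meta["B2"]),
)
dot_kernel[grid](x, y, z, N0, N1, N2, MID, B0=B0, B1=B1, B2=B2, B_MID=B_MID)

print(torch.allclose(z, dot_spec(x, y), atol=1e-4))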

r"""

Puzzle 12: Quantized Matrix Mult

When doing matrix multiplication with quantized neural networks, a common strategy is to store the weight matrix in lower precision together with a shift and a scale term.

For this problem our weight will be stored in 4 bits. We can pack FPINT of these 4-bit values into a single 32-bit integer. In addition, for every GROUP of weights (in order) we store 1 scale float value and 1 shift 4-bit value. These are stored per column of the weight. The activations are stored separately in standard floats.

Mathematically it looks like:

.. math::
z_{j, k} = \sum_{l} sc_{j, \frac{l}{g}} (w_{j, l} - sh_{j, \frac{l}{g}}) \times y_{l, k}
\text{ for } j = 1\ldots N_0, k = 1\ldots N_1

Where g is the group size (GROUP).

However, it is a bit more complex since we also need to extract the 4-bit values into floats to begin with.

Note:

  • We don't consider batch size, i.e. i, in this puzzle.
  • Remember to unpack the FPINT values into separate 4-bit values. This requires some shape manipulation.
"""
FPINT = 32 // 4
GROUP = 8
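
As a quick illustration of the packing scheme (a plain-PyTorch sketch, not part of the puzzle), eight 4-bit values live in one 32-bit integer and can be recovered with shifts and a 0xF mask:

import torch

packed = torch.tensor([0x76543210], dtype=torch.int32)

# Shift each nibble down to the low bits, then mask off everything else.
shifts = torch.arange(FPINT) * 4
unpacked = (packed[:, None] >> shifts) & 0xF
print(unpacked)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])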

def quant_dot_spec(
    scale: Float32[32, 8],
    offset: Int32[32,],
    weight: Int32[32, 8],
    activation: Float32[64, 32],
) -> Float32[32, 32]:
    offset = offset.view(32, 1)

    def extract(x):
        over = torch.arange(8) * 4
        mask = 2**4 - 1
        return (x[..., None] >> over) & mask

    scale = scale[..., None].expand(-1, 8, GROUP).contiguous().view(-1, 64)
    offset = (
        extract(offset)[..., None].expand(-1, 1, 8, GROUP).contiguous().view(-1, 64)
    )
    return (scale * (extract(weight).view(-1, 64) - offset)) @ activation



@triton.jit
def quant_dot_kernel(
    scale_ptr,
    offset_ptr,
    weight_ptr,
    activation_ptr,
    z_ptr,
    N0,
    N1,
    MID,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)
    block_id_k = tl.program_id(1)
    off_j = block_id_j * B0 + tl.arange(0, B0)
    off_k = block_id_k * B1 + tl.arange(0, B1)

    mask_j = off_j < N0
    mask_k = off_k < N1

    z = tl.zeros((B0, B1), dtype=tl.float32)
    off_z = off_j[:, None] * N1 + off_k[None, :]
    mask_z = mask_j[:, None] & mask_k[None, :]

    for l in tl.range(0, MID, B_MID):
        # load scale
        off_l_div_g = tl.arange(0, B_MID // GROUP) + (l // GROUP)
        mask_l_div_g = off_l_div_g < (MID // GROUP)
        off_scale = off_j[:, None] * (MID // GROUP) + off_l_div_g[None, :]
        # print(off_scale.shape)
        mask_scale = mask_j[:, None] & mask_l_div_g[None, :]
        scale = tl.load(scale_ptr + off_scale, mask=mask_scale)

        # load shift (offset)
        # (32,), each 32bits integer store FPINT(8)*4 shifts
        shift = tl.load(offset_ptr + off_j, mask=mask_j)

        # load weight
        # note: our weight will be stored in 4bits.
        off_weight_l = l // FPINT + tl.arange(0, B_MID // FPINT)
        mask_weight_l = off_weight_l < (MID // FPINT)
        off_weight = off_j[:, None] * (MID // FPINT) + off_weight_l[None, :]
        mask_weight = mask_j[:, None] & mask_weight_l[None, :]
        weight = tl.load(weight_ptr + off_weight, mask=mask_weight, other=0)

        # load activation as normal float
        off_l = l + tl.arange(0, B_MID)
        mask_l = off_l < MID
        off_activation = off_l[:, None] * N1 + off_k[None, :]
        mask_activation = mask_l[:, None] & mask_k[None, :]
        activation = tl.load(activation_ptr + off_activation, mask=mask_activation, other=0.0)

        # unpack weight and shift
        BITS = 32 // FPINT
        unpack_offs = tl.arange(0, FPINT) * BITS
        unpack_upperbound_mask = (1 << BITS) - 1
        unpacked_shift = (shift[:, None] >> unpack_offs) & unpack_upperbound_mask
        unpacked_weight = (weight[:, :, None] >> unpack_offs) & unpack_upperbound_mask
        # quant transform
        # [BLOCK_J, 8, 1] * ([BLOCK_J, 8, 8] - [BLOCK_J, 8, 1])
        transformed_weight = scale[:, :, None] * (
            unpacked_weight - unpacked_shift[:, :, None]
        )
        # shape: [*, 64]
        transformed_weight = transformed_weight.reshape(
            unpacked_shift.shape[0], unpacked_shift.shape[-1] * FPINT
        )

        # compute
        z += tl.dot(transformed_weight, activation, allow_tf32=False)

    tl.store(z_ptr + off_z, z, mask=mask_z)
    return
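
Finally, a minimal host-side check of quant_dot_kernel against quant_dot_spec might look like the sketch below. It assumes a CUDA device, the spec's shapes (N0 = N1 = 32, MID = 64), and block sizes chosen so tl.dot sees at least 16 elements per dimension; random 32-bit integers stand in for the packed weight and offset, so every 4-bit nibble is uniform in [0, 15].

import torch
import triton

N0, N1, MID = 32, 32, 64
B0, B1, B_MID = 16, 16, 64

torch.manual_seed(0)
scale = torch.rand(N0, MID // GROUP, device="cuda")
offset = torch.randint(-2**31, 2**31 - 1, (N0,), dtype=torch.int32, device="cuda")
weight = torch.randint(-2**31, 2**31 - 1, (N0, MID // FPINT), dtype=torch.int32, device="cuda")
activation = torch.randn(MID, N1, device="cuda")
z = torch.zeros(N0, N1, device="cuda")

grid = lambda meta: (triton.cdiv(N0, meta["B0"]), triton.cdiv(N1, meta["B1"]))
quant_dot_kernel[grid](scale, offset, weight, activation, z,
                       N0, N1, MID, B0=B0, B1=B1, B_MID=B_MID)

# quant_dot_spec builds its shift table with a CPU torch.arange, so compare on CPU.
expected = quant_dot_spec(scale.cpu(), offset.cpu(), weight.cpu(), activation.cpu())
print(torch.allclose(z.cpu(), expected, rtol=1e-3, atol=1e-3))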

Reference

  1. https://zhuanlan.zhihu.com/p/672086654
  2. https://gitcode.com/gh_mirrors/tr/Triton-Puzzles-Lite
