Manipulating Data

Loading and Storing Data

With block pointers ready we can start loading data from slow global memory into much faster on-chip memory. Triton provides the tl.load and tl.store functions for this. Since we use block pointers we can ignore most of the arguments for both loading and storing, but the docs do give us some relevant information:

pointer could be a block pointer defined by make_block_ptr, in which case:

  • mask and other must be None
  • boundary_check and padding_option can be specified to control the behavior of out-of-bound access

Ignoring the mask argument, we are left with boundary_check and padding_option. If boundary_check is enabled for an axis, out-of-bounds elements along that axis are replaced with a static value chosen via padding_option.

In most cases the tensor dimensions are a multiple of 8 or a power of 2 that divides evenly by whatever block size we use. If there is a mismatch between the tensor length and the block size (e.g. N % BLOCK_N != 0) we need to enable boundary_check. The padding is not very versatile though, since the only padding options are zero and nan. For our sum kernel zero padding works fine, but we will need some additional masking for an iterative softmax kernel.
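
As a rough sketch of such a guarded load (input_block_ptr here stands for a (1, BLOCK_N) block pointer like the one built in the kernel further below):

# Check both block axes against the tensor shape; out-of-bounds elements are
# replaced with 0.0, which is a safe identity value for a sum reduction.
input_block = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")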

Tensors

We need to keep track of an accumulator variable that stores the intermediate sum. It is essentially a scalar per program, but we have to initialize it as a Triton tensor (tl.tensor). As in Torch, Triton has a variety of ways to create a tensor. We could use tl.full(shape=(1, ), value=0, dtype=tl.float32), which is equivalent to tl.zeros(shape=(1, ), dtype=tl.float32). What is different from Torch is that the data type is not optional; you have to set it explicitly.
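
Inside a kernel this looks something like the following minimal sketch; either call produces the same length-1 float32 tensor:

# Explicit dtype is required, unlike torch.zeros.
accumulator = tl.zeros((1, ), dtype=tl.float32)
# Equivalent initialization via tl.full.
accumulator = tl.full(shape=(1, ), value=0, dtype=tl.float32)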

What you will often see is that regardless of the incoming precision (e.g. if \(A\) is torch.bfloat16 or torch.float16), accumulation is done in tl.float32 to get the highest precision available. There is not much downside here, since the accumulator is typically small in shape and the data it requires is already loaded, so no memory bandwidth is wasted.
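
One common sketch of this pattern, reusing the names from the kernel below; the explicit casts are illustrative and not part of the final kernel:

# Accumulate in float32 even if the loaded block is bf16/fp16: cast the block
# up before reducing so the accumulation itself happens in full precision ...
accumulator = tl.zeros((1, ), dtype=tl.float32)
accumulator += tl.sum(input_block.to(tl.float32), axis=1)
# ... and cast back to the output element type only when storing.
tl.store(output_block_ptr, accumulator.to(outputs_ptr.dtype.element_ty))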

The Final Iterative Sum Kernel

We can now load and store data, and we have a way to keep a running sum over the blocks of a row. In case N is not divisible by BLOCK_N we load with boundary_check enabled. We add this check unconditionally here and will see what it costs in the optimization section. The final kernel then looks as follows:

import torch
import triton
import triton.language as tl


def sum_row_blocked_iterative(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.

    Args:
        A: Tensor of shape (M, N) containing the input values.

    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)

    sum_row_blocked_iterative_kernel[(M, )](
        A_ptr=A, outputs_ptr=outputs,
        M=M, N=N,
        A_strides_x=A.stride(0), A_strides_y=A.stride(1),
        BLOCK_N=8,
    )

    return outputs


@triton.jit
def sum_row_blocked_iterative_kernel(
    A_ptr: tl.tensor, outputs_ptr: tl.tensor,
    M: tl.constexpr, N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    A_strides_x, A_strides_y,
):
    """Calculate the sum of a row of the input tensor, storing the result in
    the output. The row is processed iteratively in blocks of BLOCK_N, so the
    full row does not need to fit into SRAM at once.

    Args:
        A_ptr: Pointer to the input tensor.
        outputs_ptr: Pointer to the output tensor.
        M: Number of rows in the input tensor.
        N: Number of columns in the input tensor.
        BLOCK_N: Block size of each row we load.
        A_strides_x: Stride of the input tensor along the row dim.
        A_strides_y: Stride of the input tensor along the column dim.
    """
    program_id = tl.program_id(axis=0)

    input_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(M, N),
        strides=(A_strides_x, A_strides_y),
        offsets=(program_id, 0),
        block_shape=(1, BLOCK_N),
        order=(1, 0),
    )
    output_block_ptr = tl.make_block_ptr(
        base=outputs_ptr,
        shape=(M, ),
        strides=(1, ),
        offsets=(program_id, ),
        block_shape=(1, ),
        order=(0, ),
    )

    # Running sum for this program's row, kept in float32 for precision.
    accumulator = tl.zeros((1, ), dtype=tl.float32)
    for _ in range(0, N, BLOCK_N):
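        # Load one (1, BLOCK_N) block of the row; when N % BLOCK_N != 0 the
        # tail block is zero-padded via boundary_check.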
        input_block = tl.load(input_block_ptr, boundary_check=(0, 1))
        accumulator += tl.sum(input_block, axis=1)
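        # Advance the block pointer BLOCK_N columns to the next block of this row.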
        input_block_ptr = tl.advance(input_block_ptr, (0, BLOCK_N))

    tl.store(output_block_ptr, accumulator)
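
To sanity-check the kernel we can compare against torch.sum on a CUDA device; the shape below is an illustrative assumption, chosen so that N is not divisible by BLOCK_N:

if __name__ == "__main__":
    # N = 37 is deliberately not a multiple of BLOCK_N = 8, exercising the boundary check.
    A = torch.randn((4, 37), device="cuda", dtype=torch.float32)
    result = sum_row_blocked_iterative(A)
    torch.testing.assert_close(result, A.sum(dim=-1))
    print("row sums match torch.sum")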

I hope this was a useful primer on Triton so far. The next chapter will delve into benchmarking and optimizing kernels. For the complete code of the sum kernels, check the code folder; there are three versions: simple row sum, blocked row sum, and iterative row sum, each discussed at various points in this chapter.