Getting Started
Before we can start programming some elementary kernels, we need to create a foundation of knowledge on how Triton works. In the following chapters we will discuss how to launch a kernel in Python, how to create a kernel, and commonly used operations inside of kernels.
Blocked Algorithms
Triton is designed to implement what are called blocked algorithms. Blocking (or tiling) can drastically increase locality of reference for a variety of problems. It is the de-facto way to perform fast matrix multiplications in CUDA, and it is also an easy way to perform elementwise operations on a large set of values. Additionally, some other linear algebra calculations such as LU Factorization, SVD(!), and Cholesky Factorization can also be expressed as blocked algorithms, quite nice!
As mentioned, blocked algorithms can also be implemented in CUDA. So how does Triton do it differently? Precisely because Triton is built for blocked algorithms, it lets the user bypass many memory- and warp-level optimizations, or makes them trivial to apply by changing keyword arguments when compiling a kernel.
The Launch Grid
A Triton kernel launches a number of programs to distribute the work over blocks of data. We control the number of programs through the launch grid; a good choice depends on both the hardware present and the complexity of the algorithm.
As an example, let's calculate the sum of each row of a \(6 \times 4\) matrix \(A\). A possible kernel here is one that launches as many programs as there are rows, and lets each program perform a vector sum. The launch grid is a tuple in Python, so here it would be (6, ). This launches 6 distinct programs, each handling one row of the data. Each program is assigned a program identifier (PID for short) that is accessible inside the kernel through triton.language.program_id(). A visualization of this setup can be seen below:
The matrix \(A\) is on the left; Triton launches 6 programs that each load a row and store its sum in the output vector. Defining the launch grid in Python takes little effort, as the code snippet below shows:
import torch
def sum_row(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    launch_grid = (M, )
    sum_row_kernel[launch_grid](...)
    return outputs
For now, assume that sum_row_kernel is a valid Triton kernel. A Triton kernel is called with the funky kernel[launch_grid]() syntax. The subscript binds the launch grid, that is, how many programs to launch. The call then runs the kernel with the given arguments. It compiles the kernel on first use; compiled kernels are cached, so later calls with similar arguments reuse them.
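To make the launch grid concrete, here is a pure-Python sketch that emulates it sequentially. This is an illustration under assumptions, not how Triton runs kernels: on the GPU the programs run in parallel and in no guaranteed order. `emulate_launch` and `sum_row_program` are hypothetical helpers, not Triton functions. Since the row sums do not depend on the order, a sequential loop gives the same result.

```python
import itertools


def emulate_launch(program, grid, *args):
    """Hypothetical helper: call `program` once per program id in `grid`.

    Runs sequentially; a GPU runs the programs in parallel, in no fixed order.
    """
    for pid in itertools.product(*(range(n) for n in grid)):
        program(pid, *args)


def sum_row_program(pid, A, outputs):
    # Each program reads its own id (like tl.program_id(axis=0)) and sums one row.
    row = pid[0]
    outputs[row] = sum(A[row])


A = [[4 * r + c for c in range(4)] for r in range(6)]  # a 6 x 4 matrix
outputs = [None] * len(A)
emulate_launch(sum_row_program, (len(A),), A, outputs)  # launch grid (6, )
print(outputs)  # [6, 22, 38, 54, 70, 86]
```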
We can also divide the work over both rows and columns. If we keep the number of programs at 6, each program can instead process two half rows. This requires a two-dimensional launch grid, (3, 2) for our \(6 \times 4\) matrix: three pairs of rows along the first axis and two column halves along the second:
import torch
def sum_row(inputs: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor along the final dim.
    Args:
        inputs: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = inputs.shape
    outputs = torch.empty((M,), dtype=inputs.dtype, device=inputs.device)
    # Axis 0 groups the rows in pairs, axis 1 splits the columns in chunks of 2:
    # (3, 2) for a 6 x 4 matrix, so still 6 programs in total.
    launch_grid = (M // 2, N // 2)
    sum_row_kernel[launch_grid](
        input_ptr=inputs, output_ptr=outputs,
        M=M, N=N,
        input_stride_x=inputs.stride(0), input_stride_y=inputs.stride(1),
    )
    return outputs
Again, for a high-level overview, look at the figure below. Notice how the programs are laid out as a 2D grid.
Since we have a two-dimensional launch grid, each program has an \(x\) and a \(y\) identifier. To identify the current program we need both identifiers: pid_x = triton.language.program_id(axis=0) and pid_y = triton.language.program_id(axis=1).
The change to 2D can affect performance, since a program no longer loads one contiguous block of memory: in a row-major matrix, its two half rows are separated by the rest of each row. In my experience, however, multidimensional launch grids are not very common. In the exercises we will stick to 1D grids.
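The mapping from the two program identifiers to the data can be emulated the same way. The sketch below is a sequential pure-Python stand-in. It assumes axis 0 groups rows in pairs and axis 1 splits each row into two halves, which gives a (3, 2) grid of 6 programs for a 6 × 4 matrix. Two programs contribute to every row, so each one adds a partial sum. A real kernel would need something like tl.atomic_add, or a second pass, to combine the partial sums safely.

```python
M, N = 6, 4
ROWS_PER_PROGRAM, COLS_PER_PROGRAM = 2, 2
grid = (M // ROWS_PER_PROGRAM, N // COLS_PER_PROGRAM)  # (3, 2)

A = [[4 * r + c for c in range(N)] for r in range(M)]
outputs = [0] * M  # zero-initialized, since programs add partial sums
visited = []       # (row, col) pairs touched by any program

for pid_x in range(grid[0]):      # plays the role of program_id(axis=0)
    for pid_y in range(grid[1]):  # plays the role of program_id(axis=1)
        row_start = pid_x * ROWS_PER_PROGRAM
        col_start = pid_y * COLS_PER_PROGRAM
        cols = range(col_start, col_start + COLS_PER_PROGRAM)
        for r in range(row_start, row_start + ROWS_PER_PROGRAM):
            outputs[r] += sum(A[r][c] for c in cols)  # half-row partial sum
            visited.extend((r, c) for c in cols)

print(grid)     # (3, 2)
print(outputs)  # [6, 22, 38, 54, 70, 86]
```

Every element is visited exactly once, and each row's sum is assembled from the two programs that share it.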
We will revisit the launch grid later in the optimization section, where we use the fact that the launch grid can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
triton.jit
With the launch grid defined, we can finally start working on our sum kernel.
The first step is to write a function decorated with the Triton just-in-time compilation decorator, @triton.jit.
Inside a function with this decorator we can use the Triton domain-specific language, but there are some limitations. The triton.jit docstring describes them as follows:
This function will be compiled and run on the GPU. It will only have access to:
- python primitives,
- builtins within the triton package,
- arguments to this function,
- other jit’d functions
Let's do just that and run it:
import triton
@triton.jit
def do_nothing():
    pass
do_nothing[(1, )]()
>>> def do_nothing(, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
^
SyntaxError: invalid syntax
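Interesting! The error comes from the launcher that Triton generates around the kernel. With an empty argument list, the generated signature starts with a stray comma. Our kernel will need input arguments anyway, and the error does reveal a lot of arguments that the jit decorator adds to the kernel. I've briefly documented these parameters in the table below.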
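Arg name | Arg description |
---|---|
grid | The launch grid; normally supplied through the kernel[grid] syntax. |
num_warps | A warp is a group of 32 threads on NVIDIA GPUs. How many warps should each program of this kernel use? Defaults to 4. |
num_stages | Number of software-pipelining stages the compiler uses for loops in the kernel. Defaults to 3. |
extern_libs | External libraries (such as libdevice) to link into the kernel. |
stream | The CUDA stream to launch the kernel on. |
warmup | If True, compile the kernel without launching it. |
device | Index of the GPU to run the kernel on. |
device_type | Type of device to run on, e.g. cuda. |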
But we are trying to write a row-sum kernel that imitates A.sum(axis=1)! So the first thing we do is add \(A\) as an input to the kernel. Inside the kernel, A turns into a pointer to its first element. Everything related to loading and storing data is done through pointers, so it's good to get comfortable with some basic pointer arithmetic. We also have to pass the pre-allocated output vector outputs.
import triton
import torch
def sum_row(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    launch_grid = (M, )
    sum_row_kernel[launch_grid](A, outputs)
    return outputs
@triton.jit
def sum_row_kernel(A_ptr, outputs_ptr):
    # A_ptr is now a pointer to the first element of A, similarly for outputs_ptr.
    pass
We now have a kernel that will run \(M\) different programs each with a pointer towards the first element in \(A\). What we don't have is some way to distinguish these programs to access different points in the data.
Program Identifiers
We mentioned this before, but there is a very simple way to identify the current program inside Triton: triton.language.program_id.
With that out of the way, let's start thinking about how we are going to load a single row of matrix \(A\). If we take \(A\) to be a matrix of size \(6 \times 4\) we will run \(6\) programs in total:
import triton
import triton.language as tl
import torch
def sum_row(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    launch_grid = (M, )
    sum_row_kernel[launch_grid](A, outputs)
    return outputs
@triton.jit
def sum_row_kernel(A_ptr, outputs_ptr):
    program_id = tl.program_id(axis=0)
How do we figure out the correct offset for this program_id to load the data? Let's take program_id = 5 as an example. Program 5 should load the last row in memory. Setting offset = A_ptr + program_id would instead start loading at the element with index 5. That is the second element of the second row, while the last row starts at index 5 × 4 = 20. The figure below demonstrates the issue:
In the simple case, we could set offset = A_ptr + N * program_id.
This works if our input \(A\) is contiguous in row-major order, but we can't guarantee that.[1]
In general, we can't guarantee inside the kernel that any of the strides are regular, so even for a step along a row we need to know the stride. The output vector is a different story: we allocate it ourselves with torch.empty, so it is contiguous and has a stride of 1.
We can pass the strides to the kernel. With that done, the start of the row is A_ptr + A_strides_x * program_id, where A_strides_x is the stride between rows (A.stride(0)). That only gives us a pointer to the first element of the row, though. We need pointers to the entire row block, so we also need the length of the row, another argument we add to the kernel. Getting a block of pointers is straightforward, however, since we can use triton.language.arange(0, N). The triton.language.arange function is very similar to its NumPy and Torch counterparts. We multiply it by the column stride A_strides_y and add it to the row start, which gives one pointer per element. One catch is that the size of an arange must be a compile-time constant power of two, so N has to be passed as a tl.constexpr. Now we are pointing at the right block!
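Here is a small pure-Python sketch of this offset arithmetic, assuming offsets counted in elements (pointer arithmetic on a typed Triton pointer also counts in elements). It compares a contiguous 6 × 4 tensor with a 6 × 4 view obtained by transposing a contiguous 4 × 6 tensor. `row_offsets` is a hypothetical helper.

```python
M, N = 6, 4


def row_offsets(program_id, strides, n_cols):
    """Hypothetical helper: program_id * stride_0 + arange(n_cols) * stride_1."""
    return [program_id * strides[0] + j * strides[1] for j in range(n_cols)]


contiguous = (N, 1)  # A.stride() of a contiguous 6 x 4 tensor
transposed = (1, M)  # A.stride() of A = C.t(), with C a contiguous 4 x 6 tensor

print(5 * N)                          # 20, the naive start of row 5
print(row_offsets(5, contiguous, N))  # [20, 21, 22, 23]
print(row_offsets(5, transposed, N))  # [5, 11, 17, 23]
```

For the transposed view, the naive start N * program_id = 20 is not even part of row 5. This is why the kernel needs both strides.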
With that done, let's add the additional input arguments to ensure we load the right row from \(A\). The updated code can be seen below:
import triton
import triton.language as tl
import torch
def sum_row(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    launch_grid = (M, )
    sum_row_kernel[launch_grid](
        A, outputs,
        N=N,
        A_strides_x=A.stride(0), A_strides_y=A.stride(1),
    )
    return outputs
@triton.jit
def sum_row_kernel(
    A_ptr, outputs_ptr,
    N: tl.constexpr,  # tl.arange needs a compile-time (power-of-two) size
    A_strides_x, A_strides_y,
):
    program_id = tl.program_id(axis=0)
    # Start of this program's row, plus one offset per column of the row.
    offsets = A_ptr + program_id * A_strides_x + tl.arange(0, N) * A_strides_y
Pointer arithmetic is not always obvious, though, and there is a better, albeit experimental, way to create block pointers that we will discuss next.
[1] For a good introduction to strides in PyTorch, see https://zhang-yang.medium.com/explain-pytorch-tensor-stride-and-tensor-storage-with-code-examples-50e637f1076d
Block Pointers
Pointer arithmetic can be tedious work, and it's easy to mess up. On top of that, we have not yet loaded 2D blocks of data, which require 2D blocks of pointers. The following snippet from the official Triton tutorial on matrix multiplication shows the kind of index arithmetic this leads to:
# Program ID
pid = tl.program_id(axis=0)
# Number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# Number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# Number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Id of the group this program is in
group_id = pid // num_pid_in_group
# Row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *Within groups*, programs are ordered in a column-major order
# Row-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# Col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m
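To see what this snippet computes, here is a direct pure-Python transcription. Python's `min` and a small `cdiv` helper stand in for the Triton versions. It is applied to a hypothetical 4 × 3 grid of output tiles with GROUP_SIZE_M = 2. Consecutive program ids walk down a group of two tile rows before moving to the next column, which, according to the Triton tutorial, improves L2 cache reuse.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as tl.cdiv."""
    return (a + b - 1) // b


def grouped_pid(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Pure-Python transcription of the grouped-ordering snippet."""
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


# M = 64, N = 48 with 16 x 16 tiles gives a 4 x 3 grid of output tiles.
tiles = [grouped_pid(pid, 64, 48, 16, 16, 2) for pid in range(12)]
print(tiles)
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2),
#  (2, 0), (3, 0), (2, 1), (3, 1), (2, 2), (3, 2)]
```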
Let's steer clear of all that, and start using the block pointer functionality that is still an experimental feature. It changes the setup from this:
import triton
import triton.language as tl
import torch
def sum_row(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    launch_grid = (M, )
    sum_kernel[launch_grid](
        A, outputs,
        N=N,
        A_strides_x=A.stride(0), A_strides_y=A.stride(1),
    )
    return outputs
@triton.jit
def sum_kernel(
    A_ptr, outputs_ptr,
    N: tl.constexpr,
    A_strides_x, A_strides_y,
):
    program_id = tl.program_id(axis=0)
    offsets = A_ptr + program_id * A_strides_x + tl.arange(0, N) * A_strides_y
To this:
import triton
import triton.language as tl
import torch
def sum_row(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    launch_grid = (M, )
    sum_row_kernel[launch_grid](
        A, outputs,
        M=M, N=N,
        A_strides_x=A.stride(0), A_strides_y=A.stride(1),
    )
    return outputs
@triton.jit
def sum_row_kernel(
    A_ptr, outputs_ptr,
    M, N: tl.constexpr,  # N is part of block_shape, so it must be a constexpr
    A_strides_x, A_strides_y,
):
    program_id = tl.program_id(axis=0)
    input_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(M, N),
        strides=(A_strides_x, A_strides_y),
        offsets=(program_id, 0),
        block_shape=(1, N),
        order=(1, 0),
    )
It takes a little more work and a few more arguments. In return, we can load 1D and 2D blocks with ease, and out-of-bounds accesses are handled by a boundary check instead of hand-written masks. The table below gives a brief description of each argument.
Argument | Description |
---|---|
base | The data pointer from which you want to load a block |
shape | Shape of the base tensor |
strides | Strides of the base tensor |
offsets | Position (in elements, per dimension) at which the block starts |
block_shape | Shape of the data block to load; must consist of compile-time constants |
order | The memory layout of the base tensor: its dimensions ordered from fastest- to slowest-varying, so (1, 0) for a row-major matrix |
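To see which elements a block pointer covers, here is a pure-Python sketch. `block_offsets` is a hypothetical helper, not part of Triton. It lists the element offset of every position in a 2D block and marks positions outside the tensor with None, which is what a boundary check has to catch. `order` is not modeled: it describes the memory layout, not which elements belong to the block.

```python
def block_offsets(shape, strides, offsets, block_shape):
    """Hypothetical helper: element offsets of a 2D block, None if out of bounds."""
    rows = []
    for i in range(block_shape[0]):
        row = []
        for j in range(block_shape[1]):
            r, c = offsets[0] + i, offsets[1] + j
            in_bounds = 0 <= r < shape[0] and 0 <= c < shape[1]
            row.append(r * strides[0] + c * strides[1] if in_bounds else None)
        rows.append(row)
    return rows


# Program 5 of the row-sum kernel on a contiguous 6 x 4 tensor.
print(block_offsets((6, 4), (4, 1), (5, 0), (1, 4)))  # [[20, 21, 22, 23]]
# A (2, 4) block starting at row 5 sticks out of the tensor by one row.
print(block_offsets((6, 4), (4, 1), (5, 0), (2, 4)))
# [[20, 21, 22, 23], [None, None, None, None]]
# The same block shape on a transposed view only needs different strides.
print(block_offsets((6, 4), (1, 6), (0, 0), (2, 4)))
# [[0, 6, 12, 18], [1, 7, 13, 19]]
```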
Block Pointers in 2D and Dynamic Launch Grids
We mentioned earlier that block pointers make 2D loading easier too. As an example, let's transform the block pointer to load not one row, but 2 or potentially more as the following figure indicates:
This has consequences for the launch grid, though. If we load 2 rows at a time, we need only half as many programs. If we load 4 rows at a time, the number of programs halves again. Instead of changing the launch grid by hand each time, we can make it dynamic.
The launch grid does not have to be a tuple of integers; it can also be a callable that returns one. Triton calls it with a dictionary of the kernel's (meta-)parameters, so we can select the number of programs to launch as a function of the number of rows each program processes:
import triton
import triton.language as tl
import torch
def sum_row_blocked(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    # The number of programs is derived from BLOCK_M at launch time.
    dynamic_launch_grid = lambda params: (triton.cdiv(M, params["BLOCK_M"]), )
    sum_row_blocked_kernel[dynamic_launch_grid](
        A_ptr=A, outputs_ptr=outputs,
        M=M, N=N,
        A_strides_x=A.stride(0), A_strides_y=A.stride(1),
        BLOCK_M=2,
    )
    return outputs
@triton.jit
def sum_row_blocked_kernel(
    A_ptr, outputs_ptr,
    M, N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    A_strides_x, A_strides_y,
):
    """Calculate the sum of BLOCK_M rows of the input tensor, storing the
    results in the output. We assume BLOCK_M rows fit into SRAM.
    Args:
        A_ptr: Pointer to the input tensor.
        outputs_ptr: Pointer to the output tensor.
        M: Number of rows in the input tensor.
        N: Number of columns in the input tensor.
        BLOCK_M: Number of rows summed by each program.
        A_strides_x: Stride between consecutive rows (A.stride(0)).
        A_strides_y: Stride between consecutive columns (A.stride(1)).
    """
    program_id = tl.program_id(axis=0)
    # A (BLOCK_M, N) block starting at row program_id * BLOCK_M.
    input_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(M, N),
        strides=(A_strides_x, A_strides_y),
        offsets=(program_id * BLOCK_M, 0),
        block_shape=(BLOCK_M, N),
        order=(1, 0),
    )
    # The matching BLOCK_M entries of the output vector.
    output_block_ptr = tl.make_block_ptr(
        base=outputs_ptr,
        shape=(M, ),
        strides=(1, ),
        offsets=(program_id * BLOCK_M, ),
        block_shape=(BLOCK_M, ),
        order=(0, ),
    )
    input_block = tl.load(input_block_ptr)
    tl.store(output_block_ptr, tl.sum(input_block, axis=1))
It's impressive how little code we had to change to switch from 1D to 2D, so block pointers are definitely my go-to for getting data offsets.
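The dynamic launch grid is plain Python, so we can check what it returns for different block sizes without a GPU. In this sketch, `cdiv` is a local stand-in for triton.cdiv, assumed to compute the same ceiling division. We call the lambda ourselves with a dictionary, as Triton does at launch time.

```python
def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


M = 6
dynamic_launch_grid = lambda params: (cdiv(M, params["BLOCK_M"]), )

for block_m in (1, 2, 4, 8):
    print(block_m, dynamic_launch_grid({"BLOCK_M": block_m}))
# 1 (6,)
# 2 (3,)
# 4 (2,)
# 8 (1,)
```

Note that with BLOCK_M = 4 the second program covers rows 4 to 7 of a 6-row matrix. In that case the load and store need a boundary check.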
Advancing Block Pointers
In most situations we can easily load a whole row into memory and process a row, or even a set of rows, per program. But imagine we cannot load the entire row at once, because it is too big for on-chip memory! What we can do instead is iterate over the row in blocks.
This is where tl.advance comes into play. Each program loads a block of size BLOCK_N << N, advances the block pointer by BLOCK_N columns, and repeats until it has seen the full row.
import torch
import triton
import triton.language as tl
def sum_row_blocked_iterative(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    sum_row_blocked_iterative_kernel[(M, )](
        A_ptr=A, outputs_ptr=outputs,
        M=M, N=N,
        A_strides_x=A.stride(0), A_strides_y=A.stride(1),
        BLOCK_N=8,
    )
    return outputs
@triton.jit
def sum_row_blocked_iterative_kernel(
    A_ptr, outputs_ptr,
    M, N,
    BLOCK_N: tl.constexpr,
    A_strides_x, A_strides_y,
):
    """Calculate the sum of a row of the input tensor, storing the result in
    the output. The row is processed in blocks of BLOCK_N elements, so it
    does not have to fit into SRAM at once.
    Args:
        A_ptr: Pointer to the input tensor.
        outputs_ptr: Pointer to the output tensor.
        M: Number of rows in the input tensor.
        N: Number of columns in the input tensor.
        BLOCK_N: Block size of each row we load.
        A_strides_x: Stride between consecutive rows (A.stride(0)).
        A_strides_y: Stride between consecutive columns (A.stride(1)).
    """
    program_id = tl.program_id(axis=0)
    input_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(M, N),
        strides=(A_strides_x, A_strides_y),
        offsets=(program_id, 0),
        block_shape=(1, BLOCK_N),
        order=(1, 0),
    )
    output_block_ptr = tl.make_block_ptr(
        base=outputs_ptr,
        shape=(M, ),
        strides=(1, ),
        offsets=(program_id, ),
        block_shape=(1, ),
        order=(0, ),
    )
    accumulator = tl.zeros((1, ), dtype=tl.float32)
    for _ in range(0, N, BLOCK_N):
        input_block = tl.load(input_block_ptr)
        accumulator += tl.sum(input_block, axis=1)
        # Move the block BLOCK_N columns to the right.
        input_block_ptr = tl.advance(input_block_ptr, (0, BLOCK_N))
    tl.store(output_block_ptr, accumulator)
There are some consequences for out-of-bounds memory access checking, but we will cover them in the next section.
More information is available in the official Triton tutorial on block pointers, which implements a matrix multiplication with them.
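The loop in sum_row_blocked_iterative_kernel can be emulated in pure Python. This sketch assumes N is a multiple of BLOCK_N and models the block pointer by the column offset it currently points at. `iterative_row_sum` is a hypothetical helper. Python slicing silently stops at the end of a list, whereas a Triton load without a boundary check would read past the end of the row.

```python
def iterative_row_sum(row, BLOCK_N):
    """Hypothetical helper: sum `row` block by block, like the tl.advance loop."""
    N = len(row)
    block_start = 0    # column offset the block pointer currently points at
    accumulator = 0.0  # like tl.zeros((1, ), dtype=tl.float32)
    for _ in range(0, N, BLOCK_N):
        block = row[block_start:block_start + BLOCK_N]  # like tl.load
        accumulator += sum(block)                       # like tl.sum
        block_start += BLOCK_N  # like tl.advance(ptr, (0, BLOCK_N))
    return accumulator


row = list(range(32))  # N = 32, four blocks of BLOCK_N = 8
print(iterative_row_sum(row, BLOCK_N=8))  # 496.0
```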
Manipulating Data
Loading and Storing Data
With block pointers ready, we can start loading data from global memory into much faster on-chip memory. Triton provides the tl.load and tl.store functions for this. Since we use block pointers, we can ignore most arguments of both functions, but the docs do give us some information:
pointer could be a block pointer defined by make_block_ptr, in which case:
- mask and other must be None
- boundary_check and padding_option can be specified to control the behavior of out-of-bound access
Ignoring mask and other, we are left with boundary_check and padding_option. boundary_check is a tuple of the dimensions to check, e.g. (0, 1) for both dimensions of a 2D block. Out-of-bound elements in those dimensions are then filled with the value chosen by padding_option.
In most cases, the tensor sizes we deal with are a multiple of 8 or a power of 2, happily divisible by whatever block size we use. If there is a mismatch between the tensor length and the block size (e.g. N % BLOCK_N != 0), we need to enable boundary_check.
This padding is not very versatile, though, since the only padding options are "zero" and "nan". For our sum kernel "zero" works fine, but an iterative softmax kernel will need some additional masking. There, padded elements should be -inf, so that they affect neither the row maximum nor the sum of exponentials.
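Here is a pure-Python sketch of what padding does to our row sum when N = 10 and BLOCK_N = 8. `load_block` and `padded_row_sum` are hypothetical helpers that mimic a boundary-checked load on a single row. The last lines show why zero padding is not enough for softmax: the maximum of an all-negative row comes out as 0. Padding with -inf corresponds to a pointer-based tl.load(..., mask=..., other=float("-inf")); it is not one of the padding options.

```python
import math


def load_block(row, start, BLOCK_N, padding):
    """Hypothetical helper: read BLOCK_N elements, padding out-of-bounds ones."""
    return [row[i] if i < len(row) else padding for i in range(start, start + BLOCK_N)]


def padded_row_sum(row, BLOCK_N, padding):
    """Iterative row sum where every load is boundary-checked."""
    total = 0.0
    for start in range(0, len(row), BLOCK_N):
        total += sum(load_block(row, start, BLOCK_N, padding))
    return total


row = [1.0] * 10                         # N = 10, BLOCK_N = 8: last block is partial
print(padded_row_sum(row, 8, 0.0))       # 10.0, like padding_option="zero"
print(padded_row_sum(row, 8, math.nan))  # nan, like padding_option="nan"

# Why softmax needs more: zero padding corrupts the row maximum.
negative_row = [-3.0, -1.0, -2.0]
print(max(load_block(negative_row, 0, 4, 0.0)))        # 0.0, wrong
print(max(load_block(negative_row, 0, 4, -math.inf)))  # -1.0, correct
```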
Tensors
We need to keep track of an accumulator variable that stores the intermediate sum. It is essentially one scalar per program, but we have to initialize it as a Triton tensor, tl.tensor.
As in Torch, Triton has a variety of ways to create a tensor. We could use tl.full(shape=(1, ), value=0, dtype=tl.float32), which is the same as tl.zeros(shape=(1, ), dtype=tl.float32). Unlike in Torch, the data type is not optional; you have to set it.
You will often see that, regardless of the incoming precision (for example when \(A\) is torch.bfloat16 or torch.float16), accumulation is done in tl.float32 to avoid losing precision.
There is little downside to this: the accumulator is small, and the data it sums is already loaded, so no extra memory bandwidth is spent.
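The effect of a low-precision accumulator is easy to demonstrate on the CPU. This sketch adds 1.0 to an accumulator 4096 times and rounds after each addition to float16 or float32, using Python's struct "e" and "f" formats. It assumes round-to-nearest-even, which is what those formats use when packing and what GPU floating-point additions do by default. `to_fp16` and `to_fp32` are hypothetical helpers.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


values = [1.0] * 4096
acc16 = 0.0
acc32 = 0.0
for v in values:
    acc16 = to_fp16(acc16 + v)  # accumulate in float16
    acc32 = to_fp32(acc32 + v)  # accumulate in float32

print(acc16)  # 2048.0
print(acc32)  # 4096.0
```

Float16 has an 11-bit significand, so above 2048 the spacing between representable values is 2. From then on, 2048 + 1 rounds back to 2048, and the float16 accumulator gets stuck.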
The Final Iterative Sum Kernel
We can now load and store data, and we have a way to keep track of the running sum over the blocks of a row. In case N is not divisible by BLOCK_N, we load using a boundary_check. We add this check in any case and will look at what it costs in the optimization section. The final kernel is then as follows:
import torch
import triton
import triton.language as tl
def sum_row_blocked_iterative(A: torch.Tensor) -> torch.Tensor:
    """Calculate the sum of a tensor A along the final dim.
    Args:
        A: Tensor of shape (M, N) containing the input values.
    Returns:
        Tensor of shape (M, ) containing the summed values.
    """
    M, N = A.shape
    outputs = torch.empty((M,), dtype=A.dtype, device=A.device)
    sum_row_blocked_iterative_kernel[(M, )](
        A_ptr=A, outputs_ptr=outputs,
        M=M, N=N,
        A_strides_x=A.stride(0), A_strides_y=A.stride(1),
        BLOCK_N=8,
    )
    return outputs
@triton.jit
def sum_row_blocked_iterative_kernel(
    A_ptr: tl.tensor, outputs_ptr: tl.tensor,
    M: tl.constexpr, N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    A_strides_x, A_strides_y,
):
    """Calculate the sum of a row of the input tensor, storing the result in
    the output. The row is processed in blocks of BLOCK_N elements, so it
    does not have to fit into SRAM at once.
    Args:
        A_ptr: Pointer to the input tensor.
        outputs_ptr: Pointer to the output tensor.
        M: Number of rows in the input tensor.
        N: Number of columns in the input tensor.
        BLOCK_N: Block size of each row we load.
        A_strides_x: Stride between consecutive rows (A.stride(0)).
        A_strides_y: Stride between consecutive columns (A.stride(1)).
    """
    program_id = tl.program_id(axis=0)
    input_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(M, N),
        strides=(A_strides_x, A_strides_y),
        offsets=(program_id, 0),
        block_shape=(1, BLOCK_N),
        order=(1, 0),
    )
    output_block_ptr = tl.make_block_ptr(
        base=outputs_ptr,
        shape=(M, ),
        strides=(1, ),
        offsets=(program_id, ),
        block_shape=(1, ),
        order=(0, ),
    )
    # Accumulate in float32, whatever the input precision.
    accumulator = tl.zeros((1, ), dtype=tl.float32)
    for _ in range(0, N, BLOCK_N):
        # If N % BLOCK_N != 0, out-of-bounds elements of the last block read as zero.
        input_block = tl.load(input_block_ptr, boundary_check=(0, 1), padding_option="zero")
        accumulator += tl.sum(input_block, axis=1)
        input_block_ptr = tl.advance(input_block_ptr, (0, BLOCK_N))
    tl.store(output_block_ptr, accumulator)
I hope this was a useful primer on Triton so far. The next chapter will delve into benchmarking and optimizing kernels. The complete code for the sum kernels is in the code folder, in three versions: simple row sum, blocked row sum and iterative row sum, each discussed in part in this chapter.