当前位置: 首页 > news >正文

Triton 入门:从编程思想到 SM、Warp、Register、SMEM、Program 与 Occupancy

本文结合chatgpt生成

1. 为什么需要 Triton?

在深度学习系统里,我们经常会遇到这样的场景:

  • PyTorch 算子能完成计算,但性能不够理想;
  • CUDA C++ 性能很强,但开发成本较高;
  • 某些操作不是标准 GEMM、Conv、Attention,而是比较定制化的融合算子;
  • 希望把多个 PyTorch operation 融合成一个 GPU kernel,减少中间 tensor 和显存读写。

例如,下面这个 softmax:

x_max = x.max(dim=1, keepdim=True).values
z = torch.exp(x - x_max)
y = z / z.sum(dim=1, keepdim=True)

从数学上看很简单,但如果直接用 PyTorch 写,它可能会被拆成多个 kernel:

  1. 一个 kernel 做 max
  2. 一个 kernel 做 x - x_max
  3. 一个 kernel 做 exp
  4. 一个 kernel 做 sum
  5. 一个 kernel 做除法。

每一步都可能读写 global memory,也就是 GPU 显存。显存带宽虽然很高,但相比寄存器、shared memory 仍然慢很多。

Triton 的目标就是让你用接近 Python 的方式写高性能 GPU kernel。它介于 PyTorch 和 CUDA C++ 之间:

层级 特点
PyTorch 易用,自动调度,但难以控制底层 kernel
Triton 写法接近 Python,但可以控制 block、warp、memory access、kernel fusion
CUDA C++ 控制力最强,但开发复杂度最高

Triton 特别适合写:

  • 向量加法;
  • LayerNorm;
  • Softmax;
  • RMSNorm;
  • 自定义激活函数;
  • fused elementwise kernel;
  • 矩阵乘法;
  • Attention 相关算子;
  • 推理框架中的自定义 kernel。

2. Triton 的核心编程思想

学习 Triton,最重要的是先摆脱一个误区:

Triton 不是让你显式写每个 GPU thread 干什么,而是让你写一个 program 处理一块数据。

在 CUDA 里,常见的抽象是:

grid -> block -> thread

会写:

int tid = blockIdx.x * blockDim.x + threadIdx.x;

然后让每个 thread 处理一个或几个元素。

而在 Triton 中,更常见的抽象是:

grid -> program -> vectorized operations

Triton 里的 program 可以粗略类比 CUDA 的一个 block / CTA,但不应该完全等同。你通常不直接写 thread id,而是写:

pid = tl.program_id(0)
offsets = tl.arange(0, BLOCK_SIZE)

这里的 pid 是当前 Triton program 的编号,offsets 是一个向量。

例如:

offsets = tl.arange(0, 8)

可以理解为:

offsets = [0, 1, 2, 3, 4, 5, 6, 7]

然后:

x = tl.load(x_ptr + offsets)

表示一次加载多个元素。

所以 Triton 的基本风格是:

一个 program 负责一块数据,program 内部用向量化表达一组元素的计算。

这和 CUDA 的思维有明显不同。CUDA 更像是“我安排每个 thread 做什么”,Triton 更像是“我安排一个 program 处理一个 tile/block,然后由编译器把这些向量操作映射到底层 GPU 执行”。

3. 一个最小 Triton Kernel:向量加法

先看一个简单例子:计算

\[c_i = a_i + b_i \]

代码如下:

import torch
import triton
import triton.language as tl@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):pid = tl.program_id(0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)mask = offsets < n_elementsa = tl.load(a_ptr + offsets, mask=mask, other=0.0)b = tl.load(b_ptr + offsets, mask=mask, other=0.0)c = a + btl.store(c_ptr + offsets, c, mask=mask)def vector_add(a, b):c = torch.empty_like(a)n_elements = a.numel()BLOCK_SIZE = 1024grid = (triton.cdiv(n_elements, BLOCK_SIZE),)vector_add_kernel[grid](a, b, c, n_elements, BLOCK_SIZE=BLOCK_SIZE)return c

这段代码里面有几个 Triton 最基础的概念:

@triton.jit

表示这是一个 Triton kernel,会被 JIT 编译成 GPU 代码。

pid = tl.program_id(0)

表示当前 program 在 grid 第 0 维上的编号。

offsets = block_start + tl.arange(0, BLOCK_SIZE)

表示当前 program 要处理的一组元素下标。

mask = offsets < n_elements

用于防止越界访问。

tl.load(...)
tl.store(...)

表示从 global memory 读取和写入。

BLOCK_SIZE: tl.constexpr

表示 BLOCK_SIZE 是编译期常量。不同的 BLOCK_SIZE 往往会生成不同版本的 kernel。


4. Triton Kernel 是什么?

在 GPU 编程中,kernel 指的是运行在 GPU 上的函数。

在 PyTorch 里,通常写:

y = x + 1

底层可能会启动一个 GPU kernel 来完成加法。

在 Triton 里,自己定义 kernel:

@triton.jit
def my_kernel(...):...

然后通过下面这种语法启动:

my_kernel[grid](arg1, arg2, ..., META_ARG=value)

例如:

vector_add_kernel[grid](a, b, c, n_elements, BLOCK_SIZE=1024)

这里的:

grid = (num_programs,)

表示启动多少个 Triton programs。

可以把一次 kernel launch 理解为:

CPU 端发起一次 GPU 任务
GPU 上启动 grid 中指定数量的 programs
每个 program 处理一块数据

5. Program 是什么?

Program 是 Triton 最重要的抽象之一。

一个 program 通常负责一个 tile/block 的数据。例如:

  • 向量加法中,一个 program 处理 BLOCK_SIZE 个元素;
  • softmax 中,一个 program 可以处理一整行;
  • matmul 中,一个 program 可以处理输出矩阵 C 的一个 tile,比如 BLOCK_M x BLOCK_N
  • attention 中,一个 program 可以处理 query/block 和 key/value/block 的一部分。

以向量加法为例:

BLOCK_SIZE = 1024
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

如果 n_elements = 10000,那么:

num_programs = ceil(10000 / 1024) = 10

于是:

program 0 -> 处理元素 0 ~ 1023
program 1 -> 处理元素 1024 ~ 2047
program 2 -> 处理元素 2048 ~ 3071
...

每个 program 通过:

pid = tl.program_id(0)

知道自己是谁。

然后通过:

block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)

计算自己负责的数据范围。

这就是 Triton 的基本工作模式。

6. Grid 是什么?

Grid 决定了启动多少个 programs。

例如:

grid = (1024,)

表示一维 grid,启动 1024 个 programs。

也可以是二维或三维:

grid = (M, N)

在 kernel 里可以这样获取 program id:

pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

对于矩阵乘法,经常会用二维 grid:

program_id(0) -> 输出矩阵 C 的 M 方向 tile 编号
program_id(1) -> 输出矩阵 C 的 N 方向 tile 编号

例如 C 的 shape 是 [M, N],每个 program 计算一个 BLOCK_M x BLOCK_N 的 tile:

grid_m = ceil(M / BLOCK_M)
grid_n = ceil(N / BLOCK_N)grid = (grid_m, grid_n)

那么:

program (0, 0) -> C 左上角 tile
program (0, 1) -> C 第一行第二个 tile
program (1, 0) -> C 第二行第一个 tile
...

所以 grid 的作用就是告诉 GPU:

这次 kernel launch 一共要启动多少个 Triton programs,以及这些 programs 如何组织。

7. tl.arange 与向量化编程

Triton 的一个核心特征是向量化。

在 CUDA 里,可能会写:

int idx = blockIdx.x * blockDim.x + threadIdx.x;
c[idx] = a[idx] + b[idx];

而在 Triton 里,写:

offsets = block_start + tl.arange(0, BLOCK_SIZE)
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
c = a + b
tl.store(c_ptr + offsets, c, mask=mask)

这里的 abc 都不是单个 scalar,而是一组值。

例如 BLOCK_SIZE = 8,那么:

offsets = [0, 1, 2, 3, 4, 5, 6, 7]

于是:

a = tl.load(a_ptr + offsets)

相当于加载:

a[0], a[1], a[2], ..., a[7]

这就是 Triton 的核心编程思想:

写起来像是在操作一个向量,编译器负责把它映射到底层 GPU 的并行执行。

8. Mask 为什么重要?

Triton kernel 经常会使用 block/tile 的方式处理数据。但数据总长度不一定刚好被 BLOCK_SIZE 整除。

例如:

n_elements = 10000
BLOCK_SIZE = 1024

最后一个 program 的 offsets 可能是:

9216, 9217, ..., 10239

但合法元素只有:

0 ~ 9999

所以最后 240 个 offsets 是越界的。

因此需要:

mask = offsets < n_elements

然后:

tl.load(a_ptr + offsets, mask=mask, other=0.0)
tl.store(c_ptr + offsets, c, mask=mask)

含义是:

  • mask=True 的位置正常读写;
  • mask=False 的位置不读写;
  • load 时如果 mask 为 False,就用 other 作为默认值。

在 softmax 里,越界位置通常填:

other=-float("inf")

因为 softmax 里会做:

exp(x)

而:

\[e^{-\infty} = 0 \]

不会影响求和结果。

9. tl.constexpr 是什么?

在 Triton 中,经常会看到:

BLOCK_SIZE: tl.constexpr

这表示 BLOCK_SIZE 是编译期常量。

例如:

vector_add_kernel[grid](..., BLOCK_SIZE=1024)

Triton 会针对 BLOCK_SIZE=1024 编译一个版本的 kernel。

如果之后用:

BLOCK_SIZE=2048

Triton 可能会重新编译另一个版本。

为什么要这样?

因为 GPU kernel 的性能高度依赖静态 shape 和编译期优化。比如:

offsets = tl.arange(0, BLOCK_SIZE)

如果 BLOCK_SIZE 是编译期已知的,编译器就能更好地展开、优化、分配寄存器和生成底层代码。

常见的 tl.constexpr 参数包括:

BLOCK_SIZE
BLOCK_M
BLOCK_N
BLOCK_K
num_warps
num_stages

这些参数不是普通运行时数据,而是会影响 kernel 编译结果的 meta-parameters。

10. SM 是什么?

SM,全称 Streaming Multiprocessor,是 NVIDIA GPU 的核心计算单元。

可以粗略理解为:

GPU 由很多个 SM 组成,每个 SM 里面有 warp scheduler、寄存器文件、shared memory、CUDA cores / Tensor Cores 等资源。

例如一张 GPU 可能有:

80 个 SM

当你启动一个 Triton kernel 时,GPU 会把 programs 分配到不同 SM 上执行。

如果启动了很多 programs,那么这些 programs 会被调度到各个 SM 上:

SM 0 -> program 0, program 1, ...
SM 1 -> program 2, program 3, ...
SM 2 -> program 4, program 5, ...
...

实际调度由 GPU 硬件和驱动决定。

理解 SM 很重要,因为很多性能优化问题本质上都是:

如何让所有 SM 尽量忙起来?

如果 programs 太少,很多 SM 没活干,性能会差。

例如 GPU 有 80 个 SM,但你只启动 10 个 programs,那么最多只有 10 个 SM 有工作,剩下的 SM 空闲。

所以 kernel launch 的 grid 通常要有足够多的 programs。

11. Warp 是什么?

Warp 是 GPU 执行调度的基本单位。

在 NVIDIA GPU 上,一个 warp 通常包含 32 个 threads。

1 warp = 32 threads

GPU 不是一个 thread 一个 thread 地单独调度,而是以 warp 为单位调度。

Triton 里你经常会看到:

num_warps=4

或者:

num_warps=8

它表示:

每个 Triton program 使用多少个 warps。

例如:

num_warps = 8

如果 warp size 是 32,那么一个 program 里大概会有:

8 * 32 = 256 threads

参与执行。

不过在 Triton 中,通常不直接操纵这些 threads,而是通过向量化操作表达计算。

num_warps 会影响:

  • 一个 program 内部的并行度;
  • reduction 的执行效率;
  • register 使用;
  • occupancy;
  • kernel 性能。

经验上:

  • 小 block 可以用较少 warps,比如 1、2、4;
  • 大 block 或复杂 reduction 可以用更多 warps,比如 4、8;
  • warp 太少,单个 program 并行度不足;
  • warp 太多,资源占用变高,occupancy 可能下降。

所以 num_warps 是 Triton 性能调优中的重要参数。

12. Register / Reg 是什么?

Register(寄存器) 是 GPU 上最快的存储资源之一。

每个 thread 都会使用一些 registers 存放临时变量。例如:

a = tl.load(...)
b = tl.load(...)
c = a + b

这里的 abc 在底层可能会占用寄存器。

寄存器的优点:

  • 访问速度极快;
  • 适合存放中间变量;
  • 不需要访问 global memory。

但寄存器有一个重要限制:

每个 SM 上的寄存器总量是有限的。

如果一个 kernel 每个 thread 使用很多寄存器,那么一个 SM 上能同时驻留的 programs/warps 就会减少。

例如,假设一个 SM 上有 NUM_REGS 个寄存器,每个 program 使用 num_warps 个 warp,每个 warp 有 WARP_SIZE 个 threads,一个 kernel 编译后每个 thread 使用 n_regs 个寄存器,那么一个 program 大致需要:

\[ num_{warps} \times WARP_{SIZE} \times n_{regs} \]

个寄存器。

于是仅从寄存器角度估算,一个 SM 上最多能放:

\[\left\lfloor \frac{NUM_{REGS}} {num_{warps} \times WARP_{SIZE} \times n_{regs}} \right\rfloor \]

个 programs。

这就是为什么有时候 kernel 中间变量太多,性能会变差。不是因为计算变多,而是因为寄存器压力上升,导致 occupancy 下降。

13. SMEM / Shared Memory 是什么?

SMEM 通常指 shared memory。

Shared memory 是每个 SM 内部的一块高速片上存储。

它比 global memory 快得多,但容量小得多。

粗略层次可以这样理解:

register        最快,容量最小,thread 私有
shared memory   很快,容量较小,program/block 内共享
global memory   较慢,容量最大,全局可见

在 CUDA 中,shared memory 是一个非常重要的概念。在线程块内部,多个 threads 可以通过 shared memory 协作。

在 Triton 中,你不一定总是显式声明 shared memory,但编译器可能在某些场景下使用 shared memory,例如:

  • reduction;
  • dot/matmul;
  • 数据重排;
  • pipeline;
  • 某些 block-level 协作。

Triton 编译后的 kernel metadata 中可能有:

kernel.metadata.shared

表示该 kernel 使用了多少 shared memory。

Shared memory 也会限制 occupancy。假设:

  • 一个 SM 有 SIZE_SMEM 的 shared memory;
  • 每个 program 使用 size_smem 的 shared memory;

那么从 shared memory 角度,一个 SM 最多能放:

\[\left\lfloor \frac{SIZE_{SMEM}} {size_{smem}} \right\rfloor \]

个 programs。

最终 occupancy 会同时受到 register 和 shared memory 的限制。


14. Occupancy 是什么?

Occupancy 可以粗略理解为:

一个 SM 上同时驻留了多少 warps 或 programs。

在 Triton 的语境里,我们经常关心:

每个 SM 上能同时驻留多少个 Triton programs?

影响 occupancy 的主要因素包括:

  • 每个 program 使用多少 warps;
  • GPU 每个 SM 的最大 warps 数;
  • 每个 thread 使用多少 registers;
  • 每个 program 使用多少 shared memory;
  • GPU 每个 SM 的最大 blocks/programs 数;
  • kernel 的资源使用情况。

一个简化估算是:

occupancy_by_regs = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy_by_smem = SIZE_SMEM // size_smem
occupancy = min(occupancy_by_regs, occupancy_by_smem)

这不是完整的真实 occupancy 公式,但足够帮助我们理解 Triton 代码中的调度逻辑。

高 occupancy 的好处是:

当某些 warps 等待 memory load 时,SM 可以切换去执行其他 ready warps,从而隐藏 latency。

但 occupancy 不是越高越好。

例如:

  • 某些 compute-bound kernel,更重要的是 Tensor Core 利用率;
  • 某些 memory-bound kernel,更重要的是访存 coalescing 和数据复用;
  • 某些 kernel 提高 occupancy 可能需要降低 BLOCK_SIZE,反而增加 global memory 访问;
  • 过度追求 occupancy 可能牺牲单个 program 的计算效率。

所以正确理解应该是:

occupancy 是性能的重要指标,但不是唯一指标。它帮助 GPU 隐藏延迟,但不能单独决定性能。

15. Global Memory、Register、SMEM 之间的关系

写 Triton kernel 时,脑子里要有一个存储层次:

Global Memory / HBM|| tl.loadv
Register / SMEM|| computev
Register|| tl.storev
Global Memory / HBM

例如向量加法:

a = tl.load(a_ptr + offsets)
b = tl.load(b_ptr + offsets)
c = a + b
tl.store(c_ptr + offsets, c)

大致过程是:

  1. 从 global memory 读入 a
  2. 从 global memory 读入 b
  3. 在寄存器里计算 c = a + b
  4. c 写回 global memory。

对于 softmax:

row = tl.load(...)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
output = numerator / denominator
tl.store(...)

大致过程是:

  1. 从 global memory 读取一行;
  2. 在 register / program 内部做 max reduction;
  3. 计算 exp;
  4. 做 sum reduction;
  5. 计算除法;
  6. 写回 global memory。

Triton 的优势之一是可以把多个操作融合到一个 kernel 里,使中间结果尽量保存在 register/SMEM 中,而不是写回 global memory。

16. Triton 中的 num_warps

num_warps 是 kernel launch 时常见的 meta-parameter。

例如:

my_kernel[grid](..., BLOCK_SIZE=1024, num_warps=4)

它控制每个 program 使用多少个 warps。

选择 num_warps 时,可以粗略参考:

场景 可能的选择
小规模 elementwise 1、2、4
中等 block reduction 4、8
较大 softmax / reduction 8
matmul tile 通常需要结合 BLOCK_M/N/K autotune

num_warps 增大,可能带来:

  • program 内并行度更高;
  • reduction 更快;
  • memory transaction 更充分;
  • 但 register 占用增加;
  • occupancy 下降;
  • 对小任务可能浪费资源。

所以它需要结合具体 kernel 调。

17. Triton 中的 num_stages

num_stages 通常和软件流水线有关,尤其在 matmul、attention 这类需要循环加载数据的 kernel 中更重要。

例如矩阵乘法中,一个 program 可能不断加载 A 和 B 的 block:

load A block 0, B block 0
compute
load A block 1, B block 1
compute
load A block 2, B block 2
compute
...

如果没有流水线,可能是:

等 load 完 -> compute -> 再 load -> 再 compute

使用软件流水线后,可以尝试让 load 和 compute 重叠:

load 下一块数据的同时,计算当前块

num_stages 表示 pipeline 的阶段数。

更大的 num_stages 可能:

  • 更好地隐藏 memory latency;
  • 提高吞吐;
  • 但会增加 register / shared memory 使用;
  • 降低 occupancy。

所以 num_stages 也不是越大越好。

18. BLOCK_SIZE 是什么?

BLOCK_SIZE 是每个 program 处理的数据块大小。

在向量加法中:

BLOCK_SIZE = 1024

表示每个 program 处理 1024 个元素。

在 softmax 中:

BLOCK_SIZE = triton.next_power_of_2(n_cols)

表示每个 program 处理一整行,并把列数补到 2 的幂。

例如:

n_cols = 1000
BLOCK_SIZE = 1024

这样做的原因是 Triton 的很多向量化和 reduction 操作更适合静态大小,尤其是 2 的幂长度。

BLOCK_SIZE 太大也会有问题。

如果 BLOCK_SIZE 很大,program 内部需要维护很大的向量:

row
row_minus_max
numerator
output

这些中间变量会占用大量 registers,导致:

register pressure 上升
occupancy 下降
spill 风险增加
性能下降

所以选择 BLOCK_SIZE 时需要平衡:

  • 单个 program 处理更多数据,减少 program 数量和调度开销;
  • 但单个 program 占用资源更多,可能降低 occupancy;
  • block 太小,可能访存效率和计算效率不足;
  • block 太大,可能寄存器压力过高。

19. Softmax Kernel 示例:理解 Triton 的完整流程

下面是一个简化版 row-wise softmax kernel。

输入:

x: [n_rows, n_cols]

输出:

y: [n_rows, n_cols]

目标:

\[y_{i,j} = \frac{ e^{x_{i,j} - \max(x_i)} }{ \sum_k e^{x_{i,k} - \max(x_i)} } \]

Triton kernel:

@triton.jit
def softmax_kernel(y_ptr,x_ptr,x_stride,y_stride,n_cols,BLOCK_SIZE: tl.constexpr,
):row_idx = tl.program_id(0)offsets = tl.arange(0, BLOCK_SIZE)mask = offsets < n_colsx_row_ptr = x_ptr + row_idx * x_stridey_row_ptr = y_ptr + row_idx * y_striderow = tl.load(x_row_ptr + offsets,mask=mask,other=-float("inf"),)row_minus_max = row - tl.max(row, axis=0)numerator = tl.exp(row_minus_max)denominator = tl.sum(numerator, axis=0)output = numerator / denominatortl.store(y_row_ptr + offsets,output,mask=mask,)

Python wrapper:

def softmax(x):n_rows, n_cols = x.shapeBLOCK_SIZE = triton.next_power_of_2(n_cols)y = torch.empty_like(x)grid = (n_rows,)softmax_kernel[grid](y,x,x.stride(0),y.stride(0),n_cols,BLOCK_SIZE=BLOCK_SIZE,num_warps=8,)return y

这个版本的逻辑非常直接:

一个 program 处理一行

也就是:

program 0 -> 处理第 0 行
program 1 -> 处理第 1 行
program 2 -> 处理第 2 行
...

在每个 program 内部:

  1. tl.arange 生成列偏移;
  2. tl.load 读入一行;
  3. tl.max 做行内最大值;
  4. tl.exp 计算指数;
  5. tl.sum 做行内求和;
  6. tl.store 写回结果。

20. Persistent Program 是什么?

上面的 softmax 是普通写法:

一个 row 一个 program

如果 n_rows 很大,比如 100000,那么会启动 100000 个 programs。

另一种写法是 persistent program:

只启动接近 GPU 并发能力数量的 programs
每个 program 循环处理多行

例如:

@triton.jit
def softmax_kernel_persistent(y_ptr,x_ptr,x_stride,y_stride,n_rows,n_cols,BLOCK_SIZE: tl.constexpr,
):pid = tl.program_id(0)num_programs = tl.num_programs(0)offsets = tl.arange(0, BLOCK_SIZE)mask = offsets < n_colsfor row_idx in tl.range(pid, n_rows, num_programs):x_row_ptr = x_ptr + row_idx * x_stridey_row_ptr = y_ptr + row_idx * y_striderow = tl.load(x_row_ptr + offsets,mask=mask,other=-float("inf"),)row_minus_max = row - tl.max(row, axis=0)numerator = tl.exp(row_minus_max)denominator = tl.sum(numerator, axis=0)output = numerator / denominatortl.store(y_row_ptr + offsets,output,mask=mask,)

假设:

n_rows = 100000
num_programs = 320

那么:

program 0 -> row 0, 320, 640, ...
program 1 -> row 1, 321, 641, ...
program 2 -> row 2, 322, 642, ...
...

这种写法叫 persistent style。

它的核心思想是:

不为每一行都创建一个 program,而是创建一批长期运行的 programs,让它们持续从任务列表中取活干。

这样可以减少调度开销,并让 SM 持续保持忙碌。

21. 结合硬件信息估算 Program 数量

在一些 Triton 教程或高性能代码中,你会看到类似逻辑:

device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]

这些是在读取 GPU 硬件属性。

然后可能会做:

kernel = softmax_kernel.warmup(...)
kernel._init_handles()n_regs = kernel.n_regs
size_smem = kernel.metadata.sharedoccupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)

这里的逻辑是:

  1. 先预编译 kernel;
  2. 拿到 kernel 的寄存器使用量;
  3. 拿到 kernel 的 shared memory 使用量;
  4. 根据硬件资源估算每个 SM 能驻留多少 programs;
  5. 总 program 数设置为:
NUM_SM * occupancy

这就是 persistent program 里常见的启动规模估算方式。

它的含义是:

每个 SM 上放 occupancy 个 programs
GPU 总共有 NUM_SM 个 SM
所以一共启动 NUM_SM * occupancy 个 programs

当然真实 occupancy 还受其他硬件限制影响,这里只是一个实用估算。

22. Triton 里的 Kernel、Program、Warp、Thread、SM 的关系

可以用下面这张图理解:

一次 kernel launch|v
启动一个 grid|v
grid 里面有很多 Triton programs|v
programs 被 GPU 调度到不同 SM 上|v
每个 program 使用若干 warps|v
每个 warp 包含多个 threads|v
threads 执行底层指令

从 Triton 程序员视角看:

你主要控制:
- grid 有多少 programs
- 每个 program 处理什么 tile
- 每个 program 的 BLOCK_SIZE/BLOCK_M/BLOCK_N/BLOCK_K
- 每个 program 用多少 num_warps
- pipeline 用多少 num_stages

从 GPU 硬件视角看:

硬件负责:
- 把 programs 分配到 SM
- 以 warp 为单位调度执行
- 管理 registers、shared memory、global memory access

所以 Triton 代码的优化,本质上是在调这几个层级之间的关系:

数据 tile 大小-> program 资源占用-> register / shared memory pressure-> occupancy-> SM 利用率-> kernel 性能

23. Memory Coalescing:访存是否连续

写 Triton kernel 时,除了关注 occupancy,还要关注 memory access pattern。

例如:

offsets = block_start + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offsets)

这是连续访问:

x[0], x[1], x[2], ..., x[1023]

这种访问通常比较高效,因为 GPU 可以合并内存访问,也就是 coalesced memory access。

如果你写成:

offsets = block_start + tl.arange(0, BLOCK_SIZE) * stride

stride 很大时,每个 lane 访问的位置相隔很远,访存效率可能下降。

在 GPU 上,很多 kernel 的瓶颈不是计算,而是显存访问。

例如向量加法:

c = a + b

每个元素只做一次加法,却要读 a、读 b、写 c。算术强度很低,通常是 memory-bound。

而矩阵乘法:

C = A @ B

每个元素会参与大量乘加,算术强度较高,更可能是 compute-bound。

所以 Triton 优化时要先判断 kernel 类型:

类型 瓶颈
elementwise 通常 memory bandwidth
softmax/layernorm memory + reduction
matmul compute / Tensor Core / data reuse
attention memory + compute + tiling

24. Kernel Fusion:Triton 的重要价值

Triton 的一个重要用途是做 kernel fusion。

例如:

y = torch.relu(x * scale + bias)

如果用 PyTorch 朴素写法,可能产生多个中间 tensor:

tmp1 = x * scale
tmp2 = tmp1 + bias
y = torch.relu(tmp2)

这会多次读写 global memory。

用 Triton 可以写成一个 kernel:

x = tl.load(...)
out = tl.maximum(x * scale + bias, 0.0)
tl.store(...)

中间结果不写回 global memory,而是保存在 register 里。

这可以显著减少 memory traffic。

对于深度学习推理,kernel fusion 非常重要。例如:

  • bias + activation fusion;
  • residual add + layernorm fusion;
  • RMSNorm fusion;
  • rotary embedding fusion;
  • attention 中的 block-wise softmax fusion;
  • quantization/dequantization fusion。

Triton 的优势就在于:你可以用相对简单的代码快速实现这些融合 kernel。

25. Matmul 中 Triton 的 Tile 思想

虽然初学 Triton 可以从 elementwise 和 softmax 入手,但真正体现 Triton 威力的是 matmul。

矩阵乘法:

\[C = A B \]

其中:

A: [M, K]
B: [K, N]
C: [M, N]

Triton 里通常让一个 program 计算 C 的一个 tile:

C_tile: [BLOCK_M, BLOCK_N]

然后沿 K 维循环:

A_tile: [BLOCK_M, BLOCK_K]
B_tile: [BLOCK_K, BLOCK_N]

不断累加:

\[C_{tile} += A_{tile} \times B_{tile} \]

伪代码类似:

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr,M: tl.constexpr,N: tl.constexpr,K: tl.constexpr,BLOCK_M: tl.constexpr,BLOCK_N: tl.constexpr,BLOCK_K: tl.constexpr,
):pid_m = tl.program_id(0)pid_n = tl.program_id(1)offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)offs_k = tl.arange(0, BLOCK_K)acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)for k in range(0, K, BLOCK_K):a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k[None, :]))b = tl.load(b_ptr + (k + offs_k[:, None]) * N + offs_n[None, :])acc += tl.dot(a, b)tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)

上面代码中的offs_m[:, None] * K + (k + offs_k[None, :])会广播成一个二维地址矩阵。

这个例子里,一个 program 处理二维 tile,而不是一维 vector。

这时你会看到更多 meta-parameters:

BLOCK_M
BLOCK_N
BLOCK_K
num_warps
num_stages

它们共同影响:

  • 每个 program 计算多少输出元素;
  • A/B 数据复用率;
  • Tensor Core 使用效率;
  • register pressure;
  • shared memory usage;
  • occupancy。

26. 为什么 Triton 性能调优经常要 Autotune?

因为高性能 kernel 的参数没有固定答案。

例如 matmul 里:

BLOCK_M = 16 / 32 / 64 / 128
BLOCK_N = 16 / 32 / 64 / 128
BLOCK_K = 32 / 64 / 128
num_warps = 4 / 8
num_stages = 3 / 4 / 5

不同 GPU、不同 shape、不同 dtype,最优参数都可能不一样。

Triton 支持 autotune,例如:

@triton.autotune(configs=[triton.Config({"BLOCK_SIZE": 1024},num_warps=4,),triton.Config({"BLOCK_SIZE": 2048},num_warps=8,),],key=["n_elements"],
)
@triton.jit
def kernel(...):...

Autotune 的思想是:

给定一组选项,让 Triton 在实际运行中 benchmark 不同配置,选择性能最好的那个。

27. Triton 编程中常见的几个判断问题

读一个 Triton kernel 时,可以按下面的问题检查。

1. 一个 program 处理什么?

这是最重要的问题。

例如:

vector add: 一个 program 处理 BLOCK_SIZE 个元素
softmax: 一个 program 处理一行
matmul: 一个 program 处理 C 的一个 tile

只要这个问题没搞清楚,后面的 program_idoffsetsmask 都会混乱。

2. grid 启动了多少 programs?

看 launch 代码:

kernel[grid](...)

如果:

grid = (n_rows,)

那就是一行一个 program。

如果:

grid = (ceil(M / BLOCK_M), ceil(N / BLOCK_N))

那就是二维 tile grid。

3. 每个 program 内部的数据 shape 是什么?

看:

tl.arange(0, BLOCK_SIZE)

或者:

offs_m[:, None]
offs_n[None, :]

如果出现:

offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

说明在构造二维 tile。

4. mask 是否正确?

所有可能越界的 load/store 都应该有 mask。

尤其是:

  • 最后一个 block;
  • 非 2 的幂列数;
  • 矩阵边缘 tile;
  • batch size 不整除 block size。

5. 访存是否连续?

看指针表达式:

ptr + offsets

通常连续。

如果是:

ptr + offsets * stride

要注意 stride 是否很大。

6. 是否有不必要的 global memory 读写?

Triton 的优势是融合。要尽量让中间结果留在 register 或 SMEM 中,而不是写回 global memory。

28. 常见错误:把 Triton 当 CUDA 写

很多初学者会写出类似下面的代码:

@triton.jit
def bad_kernel(a, b, c, n_elements, BLOCK_SIZE: tl.constexpr):pid = tl.program_id(0)begin = pid * BLOCK_SIZEend = begin + BLOCK_SIZEfor i in tl.range(begin, end):c[i] = a[i] + b[i]

这种写法不符合 Triton 的优势。

Triton 更推荐向量化写法:

@triton.jit
def good_kernel(a, b, c, n_elements, BLOCK_SIZE: tl.constexpr):pid = tl.program_id(0)offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)mask = offsets < n_elementsa_vals = tl.load(a + offsets, mask=mask, other=0.0)b_vals = tl.load(b + offsets, mask=mask, other=0.0)c_vals = a_vals + b_valstl.store(c + offsets, c_vals, mask=mask)

区别在于:

错误思路:一个 program 里用 for 循环逐元素处理
正确思路:一个 program 里用 tl.arange 构造向量,一次处理一组元素

当然 Triton 中也可以用 tl.range,但它通常用于沿 K 维循环、persistent program 循环、pipeline 循环,而不是代替向量化写法逐元素操作。

29. Triton 和 CUDA 概念对照

可以粗略建立如下对应关系:

CUDA 概念 Triton 中的近似概念 说明
kernel @triton.jit 函数 GPU 上执行的函数
grid kernel[grid](...) 决定启动多少 programs
block / CTA program 粗略类比,但不完全等价
blockIdx tl.program_id(axis) 当前 program 的编号
threadIdx 通常不直接使用 Triton 用向量化表达
blockDim BLOCK_SIZE / tile size 每个 program 处理的数据规模
warp num_warps 控制 每个 program 使用多少 warps
shared memory SMEM / compiler-managed shared 可显式或隐式参与
registers kernel 临时变量占用 影响 occupancy
occupancy programs/warps per SM 影响 latency hiding

这张表只是帮助理解,不要机械地一一对应。

Triton 的抽象层级比 CUDA 更高。你主要关注 program 级别的数据划分,而不是 thread 级别的执行细节。

http://www.gsyq.cn/news/1551143.html

相关文章:

  • PS810电量计配置与通信接口实战:从核心参数到I2C/HDQ避坑指南
  • 行业内评价高的FPC贴合设备厂家推荐排行榜2026 - 品牌排行榜
  • FIFA 23 Live Editor完全指南:打造你的专属足球世界
  • SLAM Toolbox终极教程:掌握ROS 2D SLAM的7个实战技巧与5大核心优势
  • 嵌入式开发基础:SysDS Loader与Picobug监控程序实战解析
  • LiveSplit:速通玩家的终极计时器,让每一秒都精准掌控 [特殊字符]⏱️
  • ctfshow 无字母数字代码执行
  • EasyLPAC:5个关键步骤掌握专业级eUICC智能卡管理工具
  • 2026年宁波AI推广服务商实测盘点与合规推荐 - 起跑123
  • 终极指南:使用urdf-viz轻松实现机器人URDF文件可视化
  • AI公平性工程新范式:因果推断与合规落地实战
  • AI创业五大致命陷阱:从需求失焦到数据枯竭的实战避坑指南
  • 3分钟搞定小爱音箱音乐服务:DID配置的终极完整指南 [特殊字符]
  • 嵌入式AEC算法库解析:从NLMS原理到DSP工程实践
  • 黑苹果新手福音:3大核心功能揭秘OpCore Simplify的智能化配置革命
  • MC68HC16Y3 SCI模块深度解析:从UART原理到工业通信实战
  • 【Springboot毕设全套源码+文档】基于Java+springboot自行车租赁系统(丰富项目+远程调试+讲解+定制)
  • 终极指南:让老旧Mac焕发新生,免费升级最新macOS系统
  • 4层架构重构:构建企业级可视化ETL数据集成平台
  • pd.read_html实战避坑指南:HTML表格解析的三大陷阱与生产级解决方案
  • 深度解析roop-unleashed:无训练AI换脸技术的架构设计与实践指南
  • Selenium UI自动化测试环境搭建:Python+ChromeDriver实战指南
  • TWR-WIFI-AR4100评估板硬件手册深度解析与嵌入式Wi-Fi集成实战
  • Gemini Ultra技术解析:统一多模态、确定性推理与云边端协同架构
  • 构建可复现的GPU大模型训练机:A100+EPYC分布式基础设施实践
  • 国产化环境下的kkFileView实战指南:ARM架构文件预览服务部署与优化
  • 终极指南:如何在Windows 10上免费安装Windows Subsystem for Android
  • Microchip 93系列EEPROM选型指南:从命名规则到实战应用
  • OpCore Simplify:3个关键步骤让黑苹果配置从复杂变简单
  • 三相升流与单相逐相测试的差异