CSE_lecture25:GPUs and FlashAttention

More on GPUs, tiling and FlashAttention

case study: GPU

GPU的基础单元为SM,每个SM上有多个warp,在同一个warp上都会执行同一个指令,有32个SIMD线程

现代GPU,如H100,引入了tensor core,即脉冲阵列;同时core变得更具体,定制浮点数计算;还降低了用于图形学的SFU数量。但整体架构没有很大区别

SIMT使得多核程序也能在GPU上运行,thread会group成thread block,使得它们能够运行在同一个SM上,便于同步;一个grid(一个应用)包含了多个thread block

GPU的thread数量被定死,但用户可以使用任意数量的thread,只需要将其划分成多个thread block,GPU会根据SM数均分thread block;而每个thread block也会被拆分成多个wrap

GPU的存储为memory hierarchy,每个wrap有自己的register file和自己的cache,访问register的cycle数少,而访问cache和DRAM会逐层增多

在同一个warp的所有thread共享相同的program counter,因此这些线程同一时间只能做同一件事。当遇到control flow时,需要使用masking,但可能导致CUDA程序出现死锁问题

现代GPU不会再出现死锁问题,因为每个ALU有自己的PC,但这些thread还是共享相同的decoder,在不同thread可以执行不同分支的情况下,尽可能让它们还是跑同一指令

thread block的优点在于可以让里面的thread共享SM上的shared memory,同时可以方便进行thread之间的同步。而访问global memory将会很耗时,即cycle数很多的同时bandwidth也很小,同时load unit很少

example revisited: vector add

向量相加很适合GPU进行并行计算。为了知道当前线程操作的数据位置i,需要确定自己是哪个thread block idx的第几个线程:

1
2
3
4
5
6
__global__ void vecAddKernal(const float* A, const float* B, float* C, int n){
i = blockDim.x + blockIdx.x + threadIdx.x
if (i < n){
C[i] = A[i] + B[i];
}
}

但这种计算方式性能很差,即受限于访存

两个向量的存储有两种类型,行存和列存:

由于一次读取会将多个byte读取上来,因此列存的访存次数更少,缺点是需要进行一次列存的转换

optimizing attention computation: analysis from a memory access perspective

除了softmax的核心均为矩阵乘法,因此重点优化矩阵乘法,而矩阵乘法的FLOPS数量级为 $O(n^3)$

受限于带宽,在访存时为了加速,尽量访问SRAM(如registered files),而减少HBM(如global memory)的访问

对于矩阵A(N*K)和B(K*M)相乘得到C(N*M),要尽可能控制内存访问。假设一次内存读取会读上来BB个bytes,且为行存

1
2
3
4
5
6
7
for i in 0..N:
for j in 0..M:
for k in 0..K:
HBM_read(A[i][k])
HBM_read(B[k][j])
C[i][j] += A[i][k] * B[k][j]
HBM_write(C[i][j])

这个简单的实现没有利用一次读取BB个bytes的特性,因此的数量级为 $O(NMK*BB)$

1
2
3
4
5
6
7
8
9
for i in 0..N:
for j in 0..M:
for k in 0..K:
if k % BB == 0:
HBM_read(A[i][k]) # will read A[i][k:k+BB]
HBM_read(B[k][j]
C[i][j] += A[i][k] * B[k][j]
if j % BB == 0:
HBM_write(C[i][j]) # will write C[i][j:j+BB]

这个代码利用了访存优化,但受限于B矩阵按列读取,故数量级依然为 $O(NMK*BB)$

1
2
3
4
5
6
7
8
9
10
for i in 0..N:
for k in 0..K:
HBM_read(A[i][k])
for j in 0..M:
if j % BB == 0:
HBM_read(B[k][j])
HBM_read(C[i][j])
C[i][j] += A[i][k] * B[k][j]
if j % BB == 0:
HBM_write(C[i][j])

通过交换j和k的循环,使得B矩阵也是按行读取的,从而可以利用到访存优化,数量级降低到 $O(NMK)$

使用tiling进一步降低访存次数,即将矩阵乘法进行分块

1
2
3
4
5
6
7
for i in 0..N/BN.step(BN):
for j in 0..M/BM.step(BM):
for k in 0..K/BK.step(BK):
load(A[i:i+BN][k:k+BK])
load(B[k+BK][j:j+BM])
C[i:i+BN][j:j+BM] += A[i:i+BN][k:k+BK] * B[k+BK][j:j+BM]
write(C[i:i+BN][j:j+BM])

此时访存次数为:$O(\frac{N}{BN} \times \frac{M}{BM} + 2 \times \frac{N}{BN} \times \frac{M}{BM} \times \frac{K}{BK})$,进一步下降。当然这些数据应当放在SRAM上,因为SRAM有20MB,足够放下

tiling可以更改i, j, k的顺序,其改变的是分母的因子

使用auto tuning来选择最为合适的BM, BN, BK

在attention中,由于N远大于d,在tiling情况下,认为矩阵乘法和softmax的访存复杂度约为 $O(n^2)$,因此性能受限于上下文长度

flash attention进一步优化复杂度,即对attention的多步操作进行合并

忽略softmax,则只用考虑 $QK^TV$,这可以进一步优化复杂度:

1
2
3
4
5
6
7
8
9
for i in 0..N/BQ.step(BQ):  ## Iterate over the Q
load(Q[i:i + BQ , 0:d])
write(O[i:i + BQ , 0:d])
O[i:i + BQ , 0:d] = 0
for j in 0..N/BK.step(BK):
load(KT[0:d,j:j+ BK ])
load(V[j:j+ BK, 0:d])
O[i:i+BQ ,0:d] += Q[i:i+BQ,0:d] * KT[0:d,j:j+ BK ] * V[j:j+ BK, 0:d]
write(O[i:i + BQ , 0:d])

假设GPU缓存量大小为:$M \sim = B_Q * d$,那么复杂度为:$O(\frac{N^2d^2}{M})$,而d往往远小于B_Q

现在将softmax考虑进来,使用online softmax来优化归一化需要整行值的问题


CSE_lecture25:GPUs and FlashAttention
http://example.com/2026/01/01/CSE-lecture25-GPUs-and-FlashAttention/
作者
jietiDdd
发布于
2026年1月1日
许可协议