CSE_lecture25:GPUs and FlashAttention
More on GPUs, tiling and FlashAttention
case study: GPU
GPU的基础单元为SM,每个SM上有多个warp,在同一个warp上都会执行同一个指令,有32个SIMD线程
现代GPU,如H100,引入了tensor core,即脉冲阵列;同时core变得更具体,定制浮点数计算;还降低了用于图形学的SFU数量。但整体架构没有很大区别
SIMT使得多核程序也能在GPU上运行,thread会group成thread block,使得它们能够运行在同一个SM上,便于同步;一个grid(一个应用)包含了多个thread block
GPU的thread数量被定死,但用户可以使用任意数量的thread,只需要将其划分成多个thread block,GPU会根据SM数均分thread block;而每个thread block也会被拆分成多个wrap
GPU的存储为memory hierarchy,每个wrap有自己的register file和自己的cache,访问register的cycle数少,而访问cache和DRAM会逐层增多
在同一个warp的所有thread共享相同的program counter,因此这些线程同一时间只能做同一件事。当遇到control flow时,需要使用masking,但可能导致CUDA程序出现死锁问题
现代GPU不会再出现死锁问题,因为每个ALU有自己的PC,但这些thread还是共享相同的decoder,在不同thread可以执行不同分支的情况下,尽可能让它们还是跑同一指令
thread block的优点在于可以让里面的thread共享SM上的shared memory,同时可以方便进行thread之间的同步。而访问global memory将会很耗时,即cycle数很多的同时bandwidth也很小,同时load unit很少
example revisited: vector add
向量相加很适合GPU进行并行计算。为了知道当前线程操作的数据位置i,需要确定自己是哪个thread block idx的第几个线程:
1 | |
但这种计算方式性能很差,即受限于访存
两个向量的存储有两种类型,行存和列存:
由于一次读取会将多个byte读取上来,因此列存的访存次数更少,缺点是需要进行一次列存的转换
optimizing attention computation: analysis from a memory access perspective
除了softmax的核心均为矩阵乘法,因此重点优化矩阵乘法,而矩阵乘法的FLOPS数量级为 $O(n^3)$
受限于带宽,在访存时为了加速,尽量访问SRAM(如registered files),而减少HBM(如global memory)的访问
对于矩阵A(N*K)和B(K*M)相乘得到C(N*M),要尽可能控制内存访问。假设一次内存读取会读上来BB个bytes,且为行存
1 | |
这个简单的实现没有利用一次读取BB个bytes的特性,因此的数量级为 $O(NMK*BB)$
1 | |
这个代码利用了访存优化,但受限于B矩阵按列读取,故数量级依然为 $O(NMK*BB)$
1 | |
通过交换j和k的循环,使得B矩阵也是按行读取的,从而可以利用到访存优化,数量级降低到 $O(NMK)$
使用tiling进一步降低访存次数,即将矩阵乘法进行分块
1 | |
此时访存次数为:$O(\frac{N}{BN} \times \frac{M}{BM} + 2 \times \frac{N}{BN} \times \frac{M}{BM} \times \frac{K}{BK})$,进一步下降。当然这些数据应当放在SRAM上,因为SRAM有20MB,足够放下
tiling可以更改i, j, k的顺序,其改变的是分母的因子
使用auto tuning来选择最为合适的BM, BN, BK
在attention中,由于N远大于d,在tiling情况下,认为矩阵乘法和softmax的访存复杂度约为 $O(n^2)$,因此性能受限于上下文长度
flash attention进一步优化复杂度,即对attention的多步操作进行合并
忽略softmax,则只用考虑 $QK^TV$,这可以进一步优化复杂度:
1 | |
假设GPU缓存量大小为:$M \sim = B_Q * d$,那么复杂度为:$O(\frac{N^2d^2}{M})$,而d往往远小于B_Q
现在将softmax考虑进来,使用online softmax来优化归一化需要整行值的问题