
作者:企鹅火烈鸟
视频资源:
https://www.bilibili.com/video/BV11tMwznEmo?spm_id_from=333.788.videopod.sections&vd_source=ed07f7f3d6eb2ac1008f77eaff3aaab0
前言
最近在回复 sglang issue的时候,sgl-kernel中的w4a8 kernel会偶遇精度问题。像是:
# Compare results
try:
> torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 28659 / 114688 (25.0%)
E Greatest absolute difference: 1.5625 at index (6, 6778) (up to 0.1 allowed)
E Greatest relative difference: 258048.0 at index (2, 4016) (up to 0.01 allowed)
test_cutlass_w4a8_moe_mm.py:246: AssertionError
------------------------------ Captured stdout call ------------------------------
Testing with batch_size=16, k=256, n=7168, num_experts=8
FAILURE: tensors are NOT close.
Max absolute difference: 1.5625
Mean absolute difference: 0.07568359375
AssertionError: Tensor-likes are not close!
问题发生的很诡异,后与社区同学一起再次review sglang代码。并且在网上通过资料学习之后解决了这一小问题,简单来说。这是一次因为异步操作而没控制好同步带来的问题,后续我将总结在nv openday 视频里学到的知识来尝试分析w4a8的kernel,并给出导致这个问题的原因。
前排提示,本文不会讲解过多的cute相关的代数知识。
w4a8 kernel 在 Hopper架构上的实现原理
kernel实现简介
顾名思义,w4a8 kernel是一个激活8bit,权重4bit的kernel实现。它通过将权重4bit反量化回8bit,使用Hopper上的GMMA实现矩阵乘法运算。在这个例子中,A是per tensor量化。B为per channel量化,细节如图上图所示。下图展示了w4a8 kernel实现的大体流程。

任务切分
在深入kernel细节之前,让我们来看看对于一次矩阵计算。我们是如何根据TileShape进行任务切分的:对于GroupedGEMM中的Group 0来说,mma操作往往对K维度进行切分。也就是B Tensor的横向,通过构建Tile这样一个概念,我们把C Tensor中Block 0 这一块的计算拆分成多个Stage进行计算,来减轻每一次计算的workload

Stage之后,我们针对每一个Stage还可以进行更深入的切分:如下图所示,对应每一个Stage。我们还可以拆分出Block的概念,在block上就会执行实际的WGMMA计算。
更直接的讲,切分不同的Stage的目的。是为了在多stage间让TMA的数据搬运和MMA计算有overlap的效果,而切分多block是为了share mem的搬运和WGMMA之间做overlap。总之理想情况下,在矩阵运算时。我们希望mma的计算是没有overhead的,而通过多级的overlap可以趋近这一点。

Mainloop实现
之后我们再深入一些,看看w4a8 kernel内部更细致的mainloop实现流程。在这个例子里,我们默认K维度的tile切分为128。使用K维度为32的tensorcore进行计算。由此我们可以得知,在该例子里需要4个block来计算一个k Tile。
整体流程图如下所示:
首先mainloop会先发射一个pipeline_consumer_wait的指令,它可以确保在gemm流程中TMA的数据已经加载完成。 每个block会先load进一个MK的矩阵,然后对这个矩阵做int4 -> fp8的dequant。 最后做mma的操作(在这里使用了swapAB的思路,所以是load MK矩阵然后做dequant) 不同stage之间,也通过pipeline_consumer_wait确保数据的加载。 在Stage1将要计算mma之前,我们需要一个wait指令来确保stage0的mma已计算完毕。因此在这里我们需要一个warpgroup_wait<3>() 的指令来兜底。这里的3代表有3条mma还没结束,对这个例子来说也就是第一条mma的计算已经结束了。

Overlap
更进一步地说,我们可以实现这样的流程,也就是上一个block的mma计算可以和这个block的 load 做overlap。

Tensor core 和 swap AB
让我们展开每一个mma,来深入其中的细节。并且分析swap AB在 w4a8 kernel上的意义。对于Hopper架构的WGMMA而言,它的输入A可以是从寄存器或者shared mem上拿到。输入B必须从shared mem上拿到,因此对于weight需要做dequant的情形来看。我们让weight作为A输入,而activation作为B输入。可以让A通过cudacore进行反量化计算,而直接通过寄存器传入。避免了weight作为B传入而需要先写回shared mem的问题。

Epilogue 和 scale
省略掉了视频中的一些algebra知识,让我们来看最后一段。当我们完成了一段stage的计算之后,还差最后一步我们没有进行计算。那就是w4a8 kernel的scale。
hopper的mma是一个异步的指令,当我们发射完四次的mma之后。并不能认为所有的mma已经执行完毕而直接进行scaling,同上文所述。我们仍然需要一个wait操作来保证前四个mma操作已经执行完毕并得到正确结果。
因此在这里我们使用:
warpgroup_wait<0>();
来确保一个stage中的mma已都执行完毕,从而再进行scaling。同时因为在这里我们已经执行完四个mma的计算,我们可以直接使用 consumer_release来释放掉buffer。
大家可能也猜到了,在SGLang的w4a8 kernel里。也是丢掉了这样一个同步,让最后的结果执行错误了。

SGLang case
在SGLang中,社区同学通过这个PR修复了scaling的问题:
https://github.com/sgl-project/sglang/pull/10572

我们可以发现,在特定的位置添加上wait之后。就能修复这个问题。
写在最后
在Hopper架构的编程模型里,我们可能时常遇到因为异步操作而带来的问题。
比如scale明明没问题,但是结果却算不对。当我去掉scale之后,结果反而对了。
其实这些问题也都有迹可循,重要的是仔细操作异步操作中的同步控制。
-- 完 --
机智流推荐阅读:
1. 万字长文解答为何LLM同问不同答?OpenAI前CTO团队最新研究让大模型结果可复现
2. VLA-Adapter:北邮等团队以0.5B参数实现机器人智能新高度,还无需预训练
3. 理解和生成让任务真的能相互受益吗,还是仅仅共存?北大&百度UAE框架,统一视觉理解与生成,实现多模态模型新突破
4. 聊聊大模型推理系统之Q-Infer技术突破:GPU-CPU协同推理提速3倍背后的三大创新
关注机智流并加入 AI 技术交流群,不仅能和来自大厂名校的 AI 开发者、爱好者一起进行技术交流,同时还有HuggingFace每日精选论文与顶会论文解读、Talk分享、通俗易懂的Agent知识与项目、前沿AI科技资讯、大模型实战教学活动等。
cc | 大模型技术交流群 hf | HuggingFace 高赞论文分享群 具身 | 具身智能交流群 硬件 | AI 硬件交流群 智能体 | Agent 技术交流群