在 2024 年国际机器学习大会(ICML)上,Horace He 宣布推出了一种新的 PyTorch API——FlexAttention。FlexAttention 结合了 Torch compile、Triton 和 FlashAttention 算法,使得 GPU 在处理各种注意力机制时表现更加出色。Horace He 在社交媒体上表示:“用户不再受限于软件实现的随机性,能够在几行 PyTorch 代码中享受高效的注意力计算。”
FlashAttention-3 是最新发布的注意力机制算法,旨在提高 Transformer 模型的计算效率和精度。该算法通过优化硬件利用率和低精度计算,提升了模型的性能。FlashAttention-3 的背景可以追溯到 FlashAttention 的初代版本。初代 FlashAttention 通过减少内存读写操作,在 GPU 上加速了注意力计算。然而,随着硬件技术的进步,FlashAttention-2 在 H100 GPU 上的利用率仅为 35%。为了解决这一问题,研究团队开发了 FlashAttention-3,利用 Hopper GPU 的新功能,实现了更高的性能。
FlashAttention-3 采用了三项主要技术来加速 Hopper GPU 上的注意力计算:
- 利用 Tensor Cores 和 TMA 的异步性,通过 warp-specialization 重叠整体计算和数据移动。
- 交错块状矩阵乘法和 softmax 操作。
- 块量化和不一致处理,利用硬件对 FP8 低精度的支持。
这些技术使得 FlashAttention-3 在 H100 GPU 上的性能提升了 1.5-2.0 倍,FP16 达到了 740 TFLOPs/s(75% 利用率),FP8 接近 1.2 PFLOPs/s。此外,FP8 FlashAttention-3 的数值误差比基线 FP8 注意力低 2.6 倍。
NVIDIA 与 Colfax、Together.ai、Meta 和普林斯顿大学合作,利用 Hopper GPU 架构和 Tensor Cores,加速关键的融合注意力内核。通过使用 CUTLASS 3,FlashAttention-3 在 FP16 下的性能比 FlashAttention-2 提高了 1.5-2.0 倍,达到了 740 TFLOPs。在 FP8 下,FlashAttention-3 达到了 1.2 PFLOPs,数值误差比基线 FP8 注意力低 2.6 倍。
FlexAttention 的技术细节包括以下几个方面:
- 融合内核:FlexAttention 通过融合内核,将多个注意力机制的计算步骤合并为一个操作,从而减少了内存读写和计算开销。
- Torch compile:FlexAttention 利用了 Torch compile 的优化功能,使得计算过程更加高效。Torch compile 通过在芯片上处理数据块,减少了数据在芯片和内存之间的往返流动,从而加速了计算。
- Triton:FlexAttention 结合了 Triton 的计算能力,使得 GPU 在处理注意力机制时能够发挥其性能。Triton 提供了一种灵活的编程模型,使得用户可以编写高效的 GPU 代码。
与 FlashAttention-3 相比,FlexAttention 提供了更高的灵活性和易用性。FlashAttention-3 专注于优化特定的硬件架构和低精度计算,而 FlexAttention 则通过融合内核和计算框架,实现了对多种注意力机制的支持。
在实际应用中,FlexAttention 已经展示了其性能。例如,在高分辨率图像处理任务中,FlexAttention 通过调整注意力计算的粒度和范围,降低了计算和内存需求。在自然语言处理任务中,FlexAttention 通过融合内核和计算框架,实现了对长上下文的处理。