FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao
With its impact already well established in the community, and given the speed at which the field is moving, seeing FlashAttention-3 at NeurIPS 2024 felt almost nostalgic. Nevertheless, its significant advances in computational efficiency for scaling transformer models mean it was a well-deserved spotlight poster.
The paper demonstrates how algorithms need to be re-benchmarked and adapted when new hardware is released. FlashAttention-2 was highly efficient when it was written, achieving substantial speedups on Ampere GPUs, but the authors found that the design did not carry over well to Hopper GPUs: on an Nvidia H100, FlashAttention-2 reaches only about 35% of the GPU's theoretical peak throughput, leaving most of the hardware's compute capacity idle.
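(Here, utilisation means achieved FLOPs per second as a fraction of the GPU's theoretical peak.) Below is a minimal back-of-the-envelope sketch of how such a figure is obtained; the ~989 TFLOPs/s peak for dense BF16 Tensor Core work on an H100 SXM is the published spec, while the shapes and runtime are made-up values chosen to land near 35%.

```python
# Back-of-the-envelope utilisation estimate for an attention forward pass.
# The 989 TFLOPs/s peak (dense BF16 Tensor Cores, H100 SXM) and the example
# shapes/runtime below are illustrative assumptions, not numbers from the paper.

def attention_flops(batch: int, heads: int, seqlen: int, head_dim: int) -> int:
    # Two matmuls per (batch, head): Q @ K^T and P @ V,
    # each costing 2 * seqlen^2 * head_dim FLOPs.
    return 4 * batch * heads * seqlen * seqlen * head_dim

def utilisation(runtime_s: float, batch: int, heads: int, seqlen: int,
                head_dim: int, peak_tflops: float = 989.0) -> float:
    achieved_flops_per_s = attention_flops(batch, heads, seqlen, head_dim) / runtime_s
    return achieved_flops_per_s / (peak_tflops * 1e12)

# A hypothetical 12.6 ms forward pass at batch=4, heads=32, seqlen=8192,
# head_dim=128 corresponds to roughly 35% of peak:
print(f"{utilisation(12.6e-3, 4, 32, 8192, 128):.0%}")  # -> 35%
```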
The paper suggests three key changes, exploiting the asynchrony of the Tensor Cores and the Tensor Memory Accelerator (TMA) as well as hardware support for low-precision FP8, that raise GPU utilisation to 75-85%:
- Overlapping overall computation and data movement via warp-specialisation
- Interleaving block-wise matrix multiplication and softmax operations (see the sketch after this list)
- Block quantisation and incoherent processing that leverages hardware support for FP8 low-precision
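To make the interleaving idea concrete, here is a minimal NumPy sketch of block-wise attention with an online softmax. It is my own illustration rather than the paper's Hopper kernel, and the function name, block size and tensor shapes are arbitrary; what it shows is the per-block matmul → softmax → matmul dependency that FlashAttention-3 hides by scheduling the softmax of one block against the matrix multiplications of another.

```python
import numpy as np

def blockwise_attention(q, k, v, block_size=64):
    """Single-head attention computed one key/value block at a time with an
    online softmax; q, k, v have shape (seqlen, head_dim)."""
    seqlen, head_dim = q.shape
    scale = 1.0 / np.sqrt(head_dim)
    out = np.zeros_like(q)
    row_max = np.full(seqlen, -np.inf)   # running maximum of the scores
    row_sum = np.zeros(seqlen)           # running softmax denominator

    for start in range(0, seqlen, block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = (q @ kb.T) * scale                      # block-wise matmul (GEMM)
        new_max = np.maximum(row_max, s.max(axis=1))
        p = np.exp(s - new_max[:, None])            # softmax numerator for this block
        correction = np.exp(row_max - new_max)      # rescale earlier partial results
        row_sum = row_sum * correction + p.sum(axis=1)
        out = out * correction[:, None] + p @ vb    # block-wise matmul (GEMM)
        row_max = new_max

    return out / row_sum[:, None]

# Sanity check against naive attention.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 64)) for _ in range(3))
scores = (q @ k.T) / np.sqrt(64)
ref = np.exp(scores - scores.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
assert np.allclose(blockwise_attention(q, k, v), ref)
```

On Hopper, the exponentials run on separate multi-function units from the matrix multiplications, so FlashAttention-3 "ping-pongs" two warpgroups: while one warpgroup computes its softmax, the Tensor Cores work on the other warpgroup's GEMMs, and TMA fetches the next key/value blocks asynchronously.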
These adjustments reduce memory bottlenecks and keep the GPU's compute units busy, giving practitioners who use transformers faster training, faster inference and the ability to handle longer contexts. The authors demonstrate how, by thinking about hardware constraints and applying sound software engineering, models can be optimised without changing their core architecture. With hardware innovations always on the horizon, I’m excited to see what future iterations will bring.