The problem it solves
The attention operation at the heart of every transformer is, surprisingly, often limited
not by how much arithmetic the GPU can do but by how fast it can move data. Standard
attention computes a large sequence × sequence score matrix, applies softmax, and
multiplies by the values. Writing that big intermediate matrix out to the GPU’s
high-bandwidth memory (HBM) and reading it back repeatedly dominates the runtime. The
operation is memory-bandwidth-bound: the GPU’s compute units sit waiting on data. As
context lengths grow, this both slows things down and consumes memory that scales with the
square of the sequence length.
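To make the quadratic growth concrete, here is a small back-of-the-envelope sketch. The helper name score_matrix_bytes and the choice of 2 bytes per score (fp16/bf16) are assumptions made for illustration; they are not taken from any library.

```python
def score_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one seq_len x seq_len score matrix.

    bytes_per_element=2 assumes fp16/bf16 scores (an illustrative assumption).
    """
    return seq_len * seq_len * bytes_per_element


if __name__ == "__main__":
    # Size of the intermediate matrix for a single attention head and a
    # single sequence; multiply by heads and batch size for the full cost.
    for n in (1_024, 4_096, 32_768):
        mib = score_matrix_bytes(n) / 2**20
        print(f"seq_len={n:>6}: {mib:>8.0f} MiB per head, per sequence")
```

Under these assumptions the matrix takes 2 MiB at 1,024 tokens, 32 MiB at 4,096 tokens, and 2,048 MiB at 32,768 tokens: a 32× longer sequence costs 1,024× more memory, for each head and each sequence in the batch.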
The key idea: be IO-aware
Flash Attention, introduced in 2022, reframes attention as an IO-aware problem: the goal is to minimise traffic between fast on-chip SRAM and slow HBM. It does this by tiling: the queries, keys, and values are split into blocks that fit in the GPU’s small but fast on-chip memory, and attention is computed block by block. Crucially, it never materialises the full attention matrix in HBM. Instead it uses an online softmax that processes key/value blocks incrementally while keeping two running statistics per query row (the largest score seen so far and the running sum of exponentials), rescaling the partial output whenever the maximum changes, so the correct normalised result is produced without ever holding the whole matrix at once.
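The tiling and online softmax described above can be sketched in a few lines of NumPy. This is a CPU illustration of the arithmetic only, not the actual fused GPU kernel: it ignores SRAM management, masking, dropout, and the backward pass. The function names and the block sizes (block_q, block_k) are hypothetical choices for this example.

```python
import numpy as np


def standard_attention(q, k, v):
    """Reference version: materialises the full (n_q, n_k) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def tiled_attention(q, k, v, block_q=4, block_k=4):
    """Same function, but only one (block_q, block_k) tile of scores exists at a time."""
    n_q, d = q.shape
    d_v = v.shape[-1]
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n_q, d_v))

    for qs in range(0, n_q, block_q):
        q_blk = q[qs:qs + block_q]
        rows = q_blk.shape[0]
        m = np.full((rows, 1), -np.inf)  # running maximum score per query row
        l = np.zeros((rows, 1))          # running sum of exponentials per row
        acc = np.zeros((rows, d_v))      # unnormalised output accumulator

        for ks in range(0, k.shape[0], block_k):
            s = (q_blk @ k[ks:ks + block_k].T) * scale  # one tile of scores
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            correction = np.exp(m - m_new)  # rescales earlier partial results
            p = np.exp(s - m_new)
            l = correction * l + p.sum(axis=-1, keepdims=True)
            acc = correction * acc + p @ v[ks:ks + block_k]
            m = m_new

        out[qs:qs + block_q] = acc / l  # normalise once, at the end
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
    # The two versions agree up to floating-point rounding.
    assert np.allclose(standard_attention(q, k, v), tiled_attention(q, k, v))
```

The design point to notice is that the inner loop needs only the current tile plus three small per-row buffers (m, l, acc), so the working set depends on the block sizes rather than on the sequence length.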
Why it is exact, not approximate
This is the property that makes Flash Attention special. Many earlier speedups (sparse attention, low-rank approximations, linear attention) buy speed by computing an approximate version of attention, which can change model quality. Flash Attention does not. Its tiling and online softmax compute the same mathematical function as standard attention. The outputs can differ only at the level of floating-point rounding, because the sums are accumulated in a different order, so they are not guaranteed to match bit for bit. You get the speed and memory savings with no change to what the model computes, which is why it can be dropped into existing models and training runs safely.
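A short derivation shows why processing blocks one at a time loses nothing. The notation is an assumption introduced here for illustration: for one query row, s_k is the score against key k, v_k is the matching value vector, the keys are split into blocks B_1, ..., B_T, and the running statistics start at m_0 = -∞, ℓ_0 = 0, o_0 = 0. The block assumes the amsmath package.

```latex
\begin{aligned}
m_j    &= \max\Bigl(m_{j-1},\ \max_{k \in B_j} s_k\Bigr), \\
\ell_j &= e^{m_{j-1} - m_j}\,\ell_{j-1} + \sum_{k \in B_j} e^{s_k - m_j}, \\
o_j    &= e^{m_{j-1} - m_j}\,o_{j-1} + \sum_{k \in B_j} e^{s_k - m_j}\,v_k, \\[4pt]
\frac{o_T}{\ell_T}
       &= \frac{\sum_{k} e^{s_k - m_T}\,v_k}{\sum_{k} e^{s_k - m_T}}
        = \sum_{k} \operatorname{softmax}(s)_k\,v_k .
\end{aligned}
```

The factor e^{m_{j-1} - m_j} re-expresses everything accumulated so far relative to the new maximum, so by induction ℓ_j is the sum of e^{s_k - m_j} over all keys in blocks B_1 to B_j, and likewise for o_j. After the last block the ratio is the ordinary softmax-weighted sum of the values. No term is dropped or approximated; only the order of the additions differs from the standard computation.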
The payoff
By avoiding HBM round-trips, Flash Attention typically delivers 2–4× faster attention and reduces attention memory use dramatically. The often-quoted figure is 5–20× less memory, with the saving growing as sequences get longer, because the extra memory attention needs now scales linearly with sequence length rather than with the full quadratic attention matrix. This has two big consequences. First, training and inference get cheaper and faster. Second, and more importantly for modern models, it makes long contexts practical: sequences of tens or hundreds of thousands of tokens that would otherwise blow up memory become feasible, which is why long-context LLMs rely on it.
Versions and where it fits
The original Flash Attention was followed by Flash Attention 2, which kept the exact-attention guarantee but reworked how the computation is partitioned across GPU thread blocks and warps, cut down non-matrix-multiply work, and parallelised better along the sequence dimension. The result was roughly another 2× speedup, reaching about 50–73% of the theoretical peak throughput on A100 GPUs, close to the efficiency of plain matrix multiplication. Later iterations target newer GPU architectures and features. In practice, Flash Attention is now built into mainstream deep-learning frameworks and is enabled by default in many training and serving stacks. It pairs naturally with sparse attention for extreme context lengths: Flash Attention makes dense attention as cheap as possible, and sparse patterns reduce how much attention you compute at all.