Forcing Flash Attention onto a TPU and Learning the Hard Way

Original link: https://archerzhang.me/forcing-flash-attention-onto-a-tpu

## TPU vs. GPU: A Deep Dive into Flash Attention

This article details an attempt to port a Flash Attention kernel developed for GPUs (in Part 4) to a TPU, revealing surprising performance differences. The algorithm stays the same, but the underlying hardware and compiler optimizations dramatically change the outcome.

The initial JAX/TPU port was far slower than a fused standard attention implementation. This was traced to the TPU's architecture: its matrix multiply units (MXUs) are built for tiled matrix operations, and its large on-chip memory (VMEM) lets the full attention matrix stay on-chip for shorter sequence lengths. The first JAX implementation, using `fori_loop`, prevented the compiler from optimizing for parallelism. Switching to `jax.vmap`, which signals independence between query blocks, unlocked a large performance gain, ultimately beating fused standard attention for longer sequences. This highlights the importance of communicating intent to the compiler: TPU hardware and the XLA compiler often handle tiling and optimization automatically, making manual intervention unnecessary or even harmful.

Further digging showed that the TPU's systolic array design inherently executes tiled matrix multiplies efficiently. While custom kernels via Pallas offer fine-grained control (DMA pipelining, explicit memory placement), for many use cases leaning on XLA's automatic optimization delivers the best performance. The project underscores that optimization strategies effective on GPUs are not necessarily beneficial on TPUs, and that understanding the underlying hardware is essential for writing efficient code.

This Hacker News discussion centers on a blog post detailing the challenges of optimizing Flash Attention for Google's TPUs. The original article highlights the substantial work involved in adapting code to a new hardware accelerator — indexing, partitioning, and benchmarking — work that large companies like Google have typically already done for popular operations.

One commenter noted an interesting side effect: the blog post *felt* AI-generated. They described overly verbose sentences, inconsistent markdown formatting (notably excessive bolding), and tables appearing abruptly, all of which imposed a cognitive load on the reader. The commenter suspected an LLM had been used to draft the post, forcing them to constantly evaluate the content to separate substance from stylistic filler — a distraction rarely encountered in technical writing.

In essence, the discussion highlights both the technical hurdles of hardware optimization *and* the subtle difficulty of consuming potentially AI-generated content.

Original article

This is the fifth post in a series on LLM internals. Part 1 covered attention, Part 2 covered generation, Part 3 covered the Flash Attention algorithm, Part 4 put it on a GPU with Triton. This post takes the Triton kernel from Part 4 and ports it to a TPU.

There was a lot of good learning that came out of Part 4. But while working in Colab, I couldn’t help but notice that TPU was offered for free in the free tier. I figured — what if I just take Part 4’s flash attention and port it to TPU? I know the algorithm, I’ve written the kernel, JAX is just “numpy but compiled.” Translate, benchmark, call it a day.

It did not go that way.

The code uses JAX and runs on a TPU. To follow along, a free Colab TPU runtime works (Runtime -> Change runtime type -> TPU).




JAX/XLA: the TPU programming model

In Part 4, I wrote Triton kernels: explicit program_id, pointer arithmetic, tl.load/tl.store. The code controls exactly which bytes move where.

JAX, at first glance, is a layer above that. Operations are expressed directly as matmul, exp, where — and the XLA compiler decides how to map them to hardware. When jax.jit is invoked, here’s what happens:

  1. JAX traces the Python function — runs it once with abstract values to record which ops happen
  2. The trace becomes HLO (High-Level Operations) — a graph of ~100 primitives like dot, reduce, broadcast
  3. XLA optimizes — most importantly, it fuses sequences of elementwise ops into single kernels so intermediates never hit HBM
  4. XLA compiles to device code — PTX for GPU, VLIW instructions for TPU

The Python isn’t running on the TPU. It’s a specification that gets compiled into a static binary.
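As a toy illustration of steps 1 and 2 (my own, not from the post), jax.make_jaxpr shows the trace that gets recorded before XLA ever sees it:

```python
import jax
import jax.numpy as jnp

# Tracing runs the Python once with abstract values and records the ops.
def f(x):
    return jnp.exp(x) * 2.0

# The jaxpr is the recorded graph that XLA will lower and compile —
# it lists an exp op and a mul op, not the Python that produced them.
print(jax.make_jaxpr(f)(jnp.ones(4)))
```

The printed jaxpr contains the primitives (exp, mul) and their shapes; the original Python function is gone by the time compilation starts.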

The tradeoff: mutability is gone

Triton gives mutable pointers. tl.store(ptr, val) writes wherever the code wants. JAX arrays are immutable — there’s no out[i] = val.

Why? Because jax.jit traces the function into a pure computation graph. Mutation would create side effects that break tracing. This has concrete consequences for the flash attention loop:

| Triton (Part 4) | JAX (this post) |
| --- | --- |
| tl.store(out_ptrs, acc, mask=...) | out = lax.dynamic_update_slice(out, tile, (start, 0)) |
| for kv_start in range(0, q_end, BLOCK_KV): | jax.lax.fori_loop(0, num_k_blocks, k_body, state) |
| Mutable acc += tl.dot(weights, v) | Return new state: return (new_max, new_sum, new_acc) |
| Pointer arithmetic for tile addresses | Compiler handles data movement |

jax.lax.fori_loop: a Python for loop gets unrolled at trace time — 100 iterations create 100 copies of the loop body in the computation graph. fori_loop tells XLA “this is a loop” so it compiles to an actual hardware loop. The tradeoff: the body must be a pure function that takes state in and returns state out.

dynamic_update_slice: returns a new array with a slice replaced. “Dynamic” means the start index can be a runtime value (like q_start), but the slice size must be known at compile time.
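A minimal sketch (mine, not from the post) of both primitives together — a fori_loop whose body threads state through and writes a fixed-size slice at a runtime offset:

```python
import jax.numpy as jnp
from jax import lax

# State goes in, new state comes out — no mutation inside the body.
def body(i, carry):
    total, buf = carry
    # Fixed-size (1,) slice at runtime start index i: exactly what
    # dynamic_update_slice allows (static size, dynamic offset).
    buf = lax.dynamic_update_slice(buf, jnp.full((1,), i, dtype=buf.dtype), (i,))
    return total + i, buf

total, buf = lax.fori_loop(0, 5, body, (0, jnp.zeros(5, dtype=jnp.int32)))
# total = 0 + 1 + 2 + 3 + 4 = 10, buf = [0, 1, 2, 3, 4]
```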

Enough context. Let me write the code.


Standard causal attention

Same baseline as Parts 3 and 4 — materializes the full (n, n) score matrix:

def standard_causal_attention(Q: jax.Array, K: jax.Array, V: jax.Array) -> jax.Array:
    """Standard causal attention. Shapes: Q, K, V: (n, d) -> out: (n, d)"""
    assert Q.ndim == K.ndim == V.ndim == 2
    assert Q.shape == K.shape == V.shape

    n, d = Q.shape
    scale = jnp.float32(1.0 / math.sqrt(d))

    q = Q.astype(jnp.float32)
    k = K.astype(jnp.float32)
    v = V.astype(jnp.float32)

    scores = (q @ k.T) * scale                           # (n, n)
    causal_mask = jnp.triu(jnp.ones((n, n), dtype=bool), k=1)
    scores = jnp.where(causal_mask, -jnp.inf, scores)   # (n, n)

    weights = jax.nn.softmax(scores, axis=-1)           # (n, n)
    out = weights @ v                                    # (n, d)
    return out.astype(Q.dtype)

standard_causal_attention_jit = jax.jit(standard_causal_attention)

Standard stuff. XLA sees the entire expression and fuses it into one optimized kernel — no intermediate matrices spilling to HBM. This is the baseline to beat.
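One way to peek at what XLA sees (my own aside, not from the post) is JAX's ahead-of-time lowering API — jax.jit(f).lower(...) dumps the whole expression as one StableHLO module before optimization:

```python
import math
import jax
import jax.numpy as jnp

# A toy attention-shaped function, just for inspecting the lowering.
def tiny_attention(q, k, v):
    s = (q @ k.T) / math.sqrt(q.shape[-1])
    return jax.nn.softmax(s, axis=-1) @ v

q = k = v = jnp.ones((8, 4), dtype=jnp.float32)
# The full expression arrives at XLA as a single module;
# the matmuls show up as dot_general ops.
hlo = jax.jit(tiny_attention).lower(q, k, v).as_text()
print(hlo[:400])
```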


Flash attention in JAX

Same algorithm as Part 3’s numpy version and Part 4’s Triton kernel. Same running state — running_max, running_sum, acc. Same per-tile update:

\[ new\_max = \max(running\_max,\ \max(\text{scores})) \]

\[ rescale = e^{running\_max - new\_max} \]

\[ running\_sum \leftarrow rescale \cdot running\_sum + \sum e^{\text{scores} - new\_max} \]

\[ acc \leftarrow rescale \cdot acc + e^{\text{scores} - new\_max} \cdot V \]

The algorithm is identical. What changes is how JAX’s functional model shapes the code.
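A quick numerical sanity check (mine, not from the post) that this per-tile update reproduces an ordinary softmax-weighted sum when the scores arrive in tiles:

```python
import jax.numpy as jnp

s = jnp.array([0.5, 2.0, -1.0, 3.0])          # scores for one query row
v = jnp.array([[1.0], [2.0], [3.0], [4.0]])   # values, d = 1

# Reference: ordinary softmax over all four scores at once.
w_ref = jnp.exp(s - s.max())
ref = (w_ref / w_ref.sum()) @ v

# Online: same result, processing the scores two at a time.
running_max, running_sum, acc = -jnp.inf, 0.0, jnp.zeros(1)
for tile in (slice(0, 2), slice(2, 4)):
    st = s[tile]
    new_max = jnp.maximum(running_max, st.max())
    rescale = jnp.where(jnp.isfinite(running_max),
                        jnp.exp(running_max - new_max), 0.0)
    w = jnp.exp(st - new_max)
    running_sum = rescale * running_sum + w.sum()
    acc = rescale * acc + w @ v[tile]
    running_max = new_max

online = acc / running_sum   # matches ref to float32 precision
```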

@partial(jax.jit, static_argnames=("block_m", "block_n"))  # recompiles if block sizes change
def flash_attention_tiled(
    Q: jax.Array, K: jax.Array, V: jax.Array,
    block_m: int = 128, block_n: int = 128,
) -> jax.Array:
    """Causal Flash Attention with tiled online softmax in JAX.
    Same algorithm as Part 3 (numpy) and Part 4 (Triton).
    Lines marked # <-- JAX are where this diverges from the Triton version.
    """
    assert Q.ndim == K.ndim == V.ndim == 2
    assert Q.shape == K.shape == V.shape
    assert block_m > 0 and block_n > 0

    n, d = Q.shape
    q = Q.astype(jnp.float32)
    k_all = K.astype(jnp.float32)
    v_all = V.astype(jnp.float32)
    scale = jnp.float32(1.0 / math.sqrt(d))

    # (row_max, row_sum, accumulator) — the online softmax state
    SoftmaxState = tuple[jax.Array, jax.Array, jax.Array]

    # Pad so every dynamic_update_slice writes a full (block_m, d) chunk.
    # XLA needs static slice sizes — can't write a variable-length chunk.      # <-- JAX
    num_q_blocks = math.ceil(n / block_m)
    num_k_blocks = math.ceil(n / block_n)
    n_pad = num_q_blocks * block_m

    out = jnp.zeros((n_pad, d), dtype=jnp.float32)

    q_offsets = jnp.arange(block_m)
    k_offsets = jnp.arange(block_n)

    # Outer loop over query blocks.
    # fori_loop, not a Python for — otherwise XLA unrolls it at trace time.    # <-- JAX
    def q_body(q_block: int, out_buf: jax.Array) -> jax.Array:
        q_start = q_block * block_m
        q_idx = q_start + q_offsets                                            # (block_m,)
        q_mask = q_idx < n
        q_safe = jnp.minimum(q_idx, n - 1)                   # scalar broadcasts across vector

        q_tile = jnp.where(q_mask[:, None], q[q_safe, :], 0.0)             # (block_m, d)

        # Same running state as Part 3 and Part 4
        running_max = jnp.full((block_m,), -jnp.inf, dtype=jnp.float32)
        running_sum = jnp.zeros((block_m,), dtype=jnp.float32)
        acc = jnp.zeros((block_m, d), dtype=jnp.float32)

        # Inner loop over K/V blocks.
        # State is a tuple — fori_loop body takes it in and returns it out.    # <-- JAX
        def k_body(k_block: int, state: SoftmaxState) -> SoftmaxState:
            running_max, running_sum, acc = state

            k_start = k_block * block_n
            k_idx = k_start + k_offsets                                        # (block_n,)
            k_mask = k_idx < n
            k_safe = jnp.minimum(k_idx, n - 1)               # scalar broadcasts across vector

            k_tile = jnp.where(k_mask[:, None], k_all[k_safe, :], 0.0)     # (block_n, d)
            v_tile = jnp.where(k_mask[:, None], v_all[k_safe, :], 0.0)     # (block_n, d)

            scores = (q_tile @ k_tile.T) * scale                            # (block_m, block_n)

            causal = q_idx[:, None] >= k_idx[None, :]
            valid = q_mask[:, None] & k_mask[None, :] & causal
            scores = jnp.where(valid, scores, -jnp.inf)

            tile_max = jnp.max(scores, axis=1)                              # (block_m,)
            new_max = jnp.maximum(running_max, tile_max)

            rescale = jnp.where(
                jnp.isfinite(running_max),
                jnp.exp(running_max - new_max),
                0.0,
            )
            weights = jnp.where(
                jnp.isfinite(new_max)[:, None],
                jnp.exp(scores - new_max[:, None]),
                0.0,
            )                                                                # (block_m, block_n)

            running_sum = rescale * running_sum + jnp.sum(weights, axis=1)
            acc = rescale[:, None] * acc + weights @ v_tile

            return new_max, running_sum, acc                  # <-- JAX: return new state

        running_max, running_sum, acc = jax.lax.fori_loop(
            0, num_k_blocks, k_body, (running_max, running_sum, acc)
        )

        out_tile = jnp.where(running_sum[:, None] > 0, acc / running_sum[:, None], 0.0)

        # Can't do out_buf[q_start:, :] = out_tile — arrays are immutable.    # <-- JAX
        out_buf = jax.lax.dynamic_update_slice(out_buf, out_tile, (q_start, 0))
        return out_buf

    out = jax.lax.fori_loop(0, num_q_blocks, q_body, out)
    return out[:n, :].astype(Q.dtype)

What tripped me up

The algorithm is the same as Part 4’s Triton kernel. Here’s what actually changed.

No pointer arithmetic. In Triton, loading a tile meant computing a 2D grid of memory addresses: A_ptr + offs_row[:, None] * stride + offs_col[None, :]. In JAX, it’s q[q_safe, :] — normal array indexing. The compiler figures out the memory access pattern. This is the biggest readability win.

State goes in, state comes out. In Triton, acc is a mutable local variable — acc += tl.dot(weights, v) modifies it in place. In JAX, the fori_loop body is a pure function: takes (running_max, running_sum, acc) as input, returns updated versions. No mutation. I found this awkward at first, but it forces the code to be explicit about what state the loop carries — which is actually clarifying.

fori_loop is not optional. I initially wrote the outer loop as for q_block in range(num_q_blocks): and it compiled fine. But XLA unrolled every iteration into the graph, and compilation took forever for large sequences. fori_loop tells XLA this is a real loop. The tradeoff: the body must be a function, and there’s no breaking early. Part 4’s Triton kernel could stop the KV loop at q_end for causal early-stop. Here all K blocks get processed and the causal mask zeros out future positions — more wasted compute, but the loop structure stays simple for XLA.

Where do tiles live? In Part 4 I tracked exactly what lived in SRAM vs HBM. In JAX, there’s no control over placement. XLA decides what to keep on-chip based on the computation graph. The fori_loop structure gives it a hint: q_tile, running_max, running_sum, acc are loop-carried state, so XLA will try to keep them on-chip. But that’s trusting the compiler rather than specifying it.

q_offsets and k_offsets: these are the JAX equivalent of Part 4’s tl.arange — they create the tile index vectors used for bounds checking and masking. q_mask = q_idx < n is the same bounds check that mask = offsets < n_elements was in Triton’s vector add. And q_safe = jnp.minimum(q_idx, n - 1) is a clamped gather — it prevents out-of-bounds reads (crash prevention), while q_mask separately zeros out the garbage values from those clamped positions.
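A toy sketch (mine) of that clamp-then-mask pattern in isolation:

```python
import jax.numpy as jnp

x = jnp.arange(5.0)              # n = 5
idx = jnp.arange(4) + 3          # [3, 4, 5, 6] — last two are out of bounds
mask = idx < x.shape[0]
safe = jnp.minimum(idx, x.shape[0] - 1)    # [3, 4, 4, 4] — no OOB reads
tile = jnp.where(mask, x[safe], 0.0)       # [3., 4., 0., 0.]
```

The clamp keeps every gather index legal; the mask then zeros out the duplicated values the clamp produced.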

This is the fundamental tradeoff. Triton gives control, JAX gives portability. The same flash_attention_tiled function runs on TPU, GPU, or CPU with zero code changes. The cost is losing the ability to say “this tile lives in SRAM.”

Correctness check (on shapes that aren’t multiples of the block size, to test boundary logic):

n= 257, d= 64, blocks=(64,64)   match=True  max_abs=0.004399
n= 513, d= 64, blocks=(128,128) match=True  max_abs=0.003483
n= 777, d= 80, blocks=(128,64)  match=True  max_abs=0.005013

Note the max_abs is larger than on GPU — on TPU, XLA may use bf16 internally even when float32 is requested, which gives ~1e-3 precision instead of ~1e-5.

Memory scaling

Same story as Part 3: the score matrix is O(n²), the output is O(n·d). The flash version never allocates the score matrix:

 seq_len    scores (n^2)    output (n*d)       ratio    fits on-chip?
----------------------------------------------------------------------
     512           1.0 MB           0.1 MB         8.0x           yes
    1024           4.0 MB           0.2 MB        16.0x           yes
    2048          16.0 MB           0.5 MB        32.0x           yes
    4096          64.0 MB           1.0 MB        64.0x           yes
    8192         256.0 MB           2.0 MB       128.0x            NO
   16384        1024.0 MB           4.0 MB       256.0x            NO
   32768        4096.0 MB           8.0 MB       512.0x            NO

On GPU, the score matrix exceeds SM shared memory (~164 KB) at n=256. On TPU, the on-chip SRAM is ~128 MB — the score matrix fits until n=8192. That’s a 32x higher threshold before tiling becomes strictly necessary for capacity reasons. (More on TPU memory architecture later. These numbers are for a single attention head with d=64 — multi-head attention at d=128 with multiple heads sharing the on-chip memory would shift the crossover point down.)
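The scores column in the table is just n² float32 values; a quick reproduction (mine) of those numbers:

```python
# Reproducing the scaling table: float32 (4 bytes), single head, d = 64.
def mb(num_bytes):
    return num_bytes / 2**20

for n in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
    scores_mb = mb(n * n * 4)     # full (n, n) score matrix
    out_mb = mb(n * 64 * 4)       # (n, d) output
    fits = scores_mb < 128        # ~128 MB of VMEM on a v5e
    print(f"{n:6d} {scores_mb:9.1f} MB {out_mb:6.1f} MB  fits={fits}")
```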

Looks good. On to the benchmark.


Benchmark: the moment of truth

On GPU, flash attention was the whole point — it avoids materializing the n×n score matrix. On TPU with XLA, standard attention gets auto-fused. Time to find out if the tiling helps.

Setup: All benchmarks run on a Colab TPU v5e (single chip), JAX 0.7.2, float32 inputs, single-head (n, 64). Each timing is the mean of 10 iterations after 1 warmup, measured with block_until_ready() to exclude async dispatch. Compilation time is excluded — only runtime is measured.

To simulate “what if XLA didn’t fuse” (the GPU-without-Triton experience), I also benchmark an unfused version: three separate jitted functions with block_until_ready() between them, forcing each intermediate to materialize in HBM. And a nojit version where every single op is a separate kernel dispatch — maximum suffering.

# ── Unfused baseline: simulate GPU-without-Triton on TPU ──────────
# Each step is a separate jitted function. block_until_ready() forces
# each intermediate to materialize in HBM before the next step starts.

@jax.jit
def _step1_scores(q, k, scale, causal_mask):
    scores = (q @ k.T) * scale
    return jnp.where(causal_mask, -jnp.inf, scores)

@jax.jit
def _step2_softmax(scores):
    return jax.nn.softmax(scores, axis=-1)

@jax.jit
def _step3_output(weights, v):
    return weights @ v

def unfused_causal_attention(Q, K, V, causal_mask):
    """Attention with each step as a separate XLA dispatch — no fusion."""
    scale = jnp.float32(1.0 / math.sqrt(Q.shape[-1]))
    scores = _step1_scores(Q, K, scale, causal_mask)
    scores.block_until_ready()          # force HBM round-trip
    weights = _step2_softmax(scores)
    weights.block_until_ready()          # force HBM round-trip
    out = _step3_output(weights, V)
    return out


# ── Maximum suffering: no @jit, every op dispatches separately ────
def nojit_causal_attention(Q, K, V):
    """Every. Single. Op. Is. A. Separate. Kernel. Launch."""
    scale = 1.0 / math.sqrt(Q.shape[-1])
    scores = Q @ K.T                                    # dispatch 1
    scores.block_until_ready()
    scores = scores * scale                             # dispatch 2
    scores.block_until_ready()
    mask = jnp.triu(jnp.ones((Q.shape[0], Q.shape[0]), dtype=bool), k=1)
    scores = jnp.where(mask, -jnp.inf, scores)         # dispatch 3
    scores.block_until_ready()
    weights = jax.nn.softmax(scores, axis=-1)           # dispatch 4
    weights.block_until_ready()
    out = weights @ V                                   # dispatch 5
    out.block_until_ready()
    return out
Backend: tpu
     n   scores(MB)    VMEM?   nojit(ms)   unfused(ms)   fused(ms)   flash(ms)   fuse speedup
-----------------------------------------------------------------------------------------------
   512        1.0      yes       1.390         0.475       0.076       0.082          6.3x
  1024        4.0      yes       1.504         0.497       0.055       0.133          9.0x
  2048       16.0      yes       1.737         0.651       0.067       0.522          9.7x
  4096       64.0      yes       3.016         1.038       0.072       2.509         14.5x
  8192      256.0       NO       7.385         2.834       1.189      14.052          2.4x
 16384     1024.0       NO      25.576        10.110       4.445      89.567          2.3x
 32768     4096.0       NO         OOM           OOM      17.123     103.016             —

My flash attention is 35x slower than the fused standard at n=4096. Not a little worse. Catastrophically worse.

And look at the fuse speedup column — XLA’s fusion is doing something incredible. The unfused version forces three HBM round-trips (scores, weights, output). The fused version avoids all of them. At n=4096, that’s a 14.5x speedup just from fusion. The XLA compiler is earning its keep.

The nojit column is there for fun. Every single op — matmul, scale, mask, softmax, final matmul — dispatches as a separate kernel with a full HBM round-trip in between. 3ms at n=4096 vs 0.072ms fused. That’s what “no compiler optimization” looks like on a TPU.


What just happened?

Look at those numbers again. My flash attention — the algorithm that was the entire point of Parts 3 and 4 — is slower than unfused standard attention on TPU at n=4096.

My best theory: the fused standard path wins because XLA sees the entire softmax(Q @ K.T) @ V expression at once and compiles it into one optimized kernel — no intermediate matrices spilling to HBM. My flash attention uses fori_loop, which XLA likely compiles as a generic sequential loop. It probably can’t fuse across iterations, can’t pipeline memory loads, can’t interleave independent work. (I haven’t dumped the HLO to verify this — it’s an inference from the benchmark numbers and XLA’s documented behavior.)

But here’s the thing. The outer loop over Q blocks has zero data dependency between iterations. Each Q block reads the same K/V, maintains its own softmax state, writes to different output rows. The only truly sequential part is the inner K loop, where the running max and sum accumulate tile by tile.

fori_loop likely hides this parallelism from the compiler. XLA is a JIT compiler — it does dataflow analysis on the computation graph. If it could see that the Q blocks are independent, it could potentially schedule them in parallel, interleave their memory loads, maybe even dispatch them to different MXUs.

But fori_loop is opaque — it presents as “a loop with carried state.” At minimum, the compiler isn’t getting an explicit “these iterations are independent” signal from the code.

So what if I just… told XLA that the Q tiles have no dependencies on each other?


The vmap insight

jax.vmap transforms a function that processes one item into a function that processes a batch — and crucially, it tells XLA that every item in the batch is independent. No carried state between them.
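A toy illustration (mine, not from the post) of what vmap does:

```python
import jax
import jax.numpy as jnp

# A function written for one row becomes a batched one, and the batch
# dimension is declared independent — no carried state between rows.
def row_norm(x):                       # x: (d,)
    return x / jnp.linalg.norm(x)

batched = jax.vmap(row_norm)           # maps over a new leading axis
out = batched(jnp.eye(3) * 2.0)        # each row normalized independently
# out == identity: each row has a single entry of 2, norm 2
```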

Instead of two nested fori_loops, vmap replaces the outer Q loop and fori_loop stays only for the inner K accumulation (which genuinely is sequential). Same algorithm, same tiles, same math — just giving the compiler one piece of information it didn’t have before.

@partial(jax.jit, static_argnames=("block_m", "block_n"))
def flash_attention_vmap(Q, K, V, block_m=128, block_n=128):
    n, d = Q.shape
    scale = jnp.float32(1.0 / math.sqrt(d))
    num_q_blocks = math.ceil(n / block_m)
    num_k_blocks = math.ceil(n / block_n)
    n_pad = num_q_blocks * block_m

    k_all = K.astype(jnp.float32)
    v_all = V.astype(jnp.float32)
    k_offsets = jnp.arange(block_n)

    # Pad Q and reshape into (num_q_blocks, block_m, d)
    q_padded = jnp.zeros((n_pad, d), dtype=jnp.float32)
    q_padded = q_padded.at[:n, :].set(Q.astype(jnp.float32))
    q_blocks = q_padded.reshape(num_q_blocks, block_m, d)

    q_offsets = jnp.arange(block_m)
    q_starts = jnp.arange(num_q_blocks) * block_m

    # (row_max, row_sum, accumulator) — the online softmax state
    SoftmaxState = tuple[jax.Array, jax.Array, jax.Array]

    def one_q_block(q_tile: jax.Array, q_start: jax.Array) -> jax.Array:
        """Process one Q block against all K/V blocks.
        No data dependency on other Q blocks."""
        q_idx = q_start + q_offsets                                          # (block_m,)
        q_mask = q_idx < n

        running_max = jnp.full((block_m,), -jnp.inf, dtype=jnp.float32)
        running_sum = jnp.zeros((block_m,), dtype=jnp.float32)
        acc = jnp.zeros((block_m, d), dtype=jnp.float32)

        def k_body(k_block: int, state: SoftmaxState) -> SoftmaxState:
            running_max, running_sum, acc = state

            k_start = k_block * block_n
            k_idx = k_start + k_offsets                                      # (block_n,)
            k_mask = k_idx < n
            k_safe = jnp.minimum(k_idx, n - 1)              # scalar broadcasts across vector

            k_tile = jnp.where(k_mask[:, None], k_all[k_safe, :], 0.0)
            v_tile = jnp.where(k_mask[:, None], v_all[k_safe, :], 0.0)

            scores = (q_tile @ k_tile.T) * scale                            # (block_m, block_n)

            causal = q_idx[:, None] >= k_idx[None, :]
            valid = q_mask[:, None] & k_mask[None, :] & causal
            scores = jnp.where(valid, scores, -jnp.inf)

            tile_max = jnp.max(scores, axis=1)
            new_max = jnp.maximum(running_max, tile_max)

            rescale = jnp.where(
                jnp.isfinite(running_max),
                jnp.exp(running_max - new_max),
                0.0,
            )
            weights = jnp.where(
                jnp.isfinite(new_max)[:, None],
                jnp.exp(scores - new_max[:, None]),
                0.0,
            )

            running_sum = rescale * running_sum + jnp.sum(weights, axis=1)
            acc = rescale[:, None] * acc + weights @ v_tile

            return new_max, running_sum, acc

        running_max, running_sum, acc = jax.lax.fori_loop(
            0, num_k_blocks, k_body, (running_max, running_sum, acc)
        )

        out_tile = jnp.where(running_sum[:, None] > 0, acc / running_sum[:, None], 0.0)
        return out_tile

    # vmap over Q blocks — XLA sees all blocks at once, can interleave MXU/VPU/DMA
    all_tiles = jax.vmap(one_q_block)(q_blocks, q_starts)                    # (num_q_blocks, block_m, d)
    out = all_tiles.reshape(n_pad, d)
    return out[:n, :].astype(Q.dtype)

Results:

fori vs vmap match: True
max diff: 0.000000

     n    fori(ms)    vmap(ms)   fused(ms)   vmap speedup
------------------------------------------------------------
   512       0.074       0.065       0.065          1.1x
  1024       0.133       0.079       0.069          1.7x
  2048       0.525       0.083       0.069          6.3x
  4096       2.510       0.178       0.072         14.1x
  8192      14.061       0.587       1.194         23.9x
 16384      89.538       1.997       4.444         44.8x

45x faster at n=16384. Same algorithm. Same tiles. Same math. The only difference: vmap instead of fori_loop on the outer Q dimension.

And look at the fused column — in this benchmark, vmap flash attention doesn’t pull ahead until n=8192, when the score matrix is 256 MB and no longer fits in ~128 MB of VMEM. At n=4096, XLA’s fused standard path still wins comfortably. Below that threshold, the fully fused path keeps everything on-chip and wins. Above it, the tiled approach avoids materializing the score matrix entirely — exactly the same win as on GPU, just at a higher threshold because TPU has more on-chip memory.

This was the biggest “aha” moment of the whole project. The algorithm was never the problem. The compiler just couldn’t see the parallelism through fori_loop.


The practical story is done — the vmap fix works, and in this benchmark it beats fused standard attention once the score matrix outgrows VMEM. But I was left with the nagging question: why did the original fail so badly? What is the hardware actually doing with those tiles? The rest of this post is the rabbit hole I fell into trying to answer that. It shifts from experiment log to architecture explainer — feel free to stop here if the benchmark results are all that matters.


OK but seriously — what even is a TPU?

The vmap result is wild — 45x faster, and it even beats XLA’s fused attention at large sizes. Just from telling the compiler that Q blocks are independent. But I still don’t really understand why the original was so slow, or what the hardware is actually doing with those tiles. Time to look up how a TPU actually works.

Inside a TPU chip

A TPU v5e chip (what Colab provides in the free tier) has one TensorCore — the unit that does all compute:

TPU v5e chip
└── TensorCore
    ├── 4x MXU   (128x128 systolic arrays — the matrix multiply engines)
    ├── 1x VPU   (vector processing unit — elementwise ops, reductions)
    ├── 1x Scalar unit   (control flow, instruction dispatch, DMA orchestration)
    └── ~128 MB VMEM   (shared on-chip SRAM scratchpad)

MXU: the main event

On a GPU, the SM is built around CUDA cores — scalar ALUs, 32 of which execute in lockstep as a warp (Part 4 covered this). Tensor cores are a separate thing — specialized matrix multiply units bolted onto each SM. They accelerate matmul, but the SM’s general-purpose work still runs on CUDA cores. Tensor cores are an accelerator, not the foundation.

A TPU flips this. The MXU (Matrix Multiply Unit) isn’t a bolt-on accelerator — it IS the primary compute engine. Each MXU is a 128x128 systolic array: 16,384 multiply-accumulate cells. The v5e has 4 MXUs per chip, all fed from the same VMEM. Everything that can be expressed as a matrix multiply goes through the MXUs.

“Systolic” means data flows through the array rhythmically, like a heartbeat. One matrix is pre-loaded into the cells and stays stationary. The other streams in from the left, flowing through each cell. Every cell multiplies its resident weight by the passing activation, accumulates the partial sum, and hands data to its neighbor. By the time data exits the bottom, that’s a full matrix multiply — and no intermediate values touched memory.

VPU: not CUDA cores

The VPU (Vector Processing Unit) handles everything that isn’t a matmul: elementwise ops (ReLU, exp, add), reductions, type casts. It’s a wide SIMD vector unit — think AVX-512 on steroids, not thousands of CUDA cores.

There’s only one VPU shared across the whole chip, and it’s roughly 10x slower than the MXUs for the same FLOP count. This is why, on TPU, it matters to express as much of the computation as matmuls as possible — everything else is a relative bottleneck.

No threads

This is the biggest shift from GPU thinking.

On a GPU, memory latency is hidden by thread parallelism — when one warp stalls on a memory read, the SM switches to another (Part 4 covered this). A TPU has no threads. The scalar unit dispatches instructions to the MXUs and VPU. Latency hiding comes from pipelining: while the MXUs compute one tile, the DMA engine prefetches the next tile from HBM into VMEM. Same idea, completely different mechanism.

|  | GPU (A100) | TPU (v5e) |
| --- | --- | --- |
| Chip structure | 108 SMs, each independent | 1 TensorCore per chip |
| Matrix units | 4 tensor cores per SM (432 total) | 4 MXUs (128x128 systolic arrays) |
| Scalar/vector compute | CUDA cores (thousands of scalar ALUs) | 1 VPU (wide SIMD vector unit) |
| Execution model | Thousands of threads, warp switching | Single-threaded, pipelined |
| Latency hiding | More warps ready to go | Overlap DMA with compute |
| On-chip SRAM | ~164 KB shared memory per SM | ~128 MB VMEM per chip (shared) |

Memory hierarchy

Same structure as GPU — fast on-chip, slow off-chip — but the sizes are very different:

VMEM        ~128 MB / chip   (on-chip SRAM — shared by all 4 MXUs + VPU)
HBM         16 GB            ~820 GB/s   (off-chip — same role as GPU HBM)

An A100 SM has ~164 KB of shared memory. A TPU v5e has ~128 MB of VMEM — roughly 800x more on-chip space. Bigger tiles fit on-chip, more data reuse per HBM load. Same tiling tradeoff from Part 4 — bigger tiles = more reuse but must fit in SRAM — just with a much higher ceiling on TPU.


How data flows through a systolic array

I kept seeing the phrase “systolic array” and thinking I understood it. I did not. Let me draw it out.

Systolic array overview — 4x4 array with weights pre-loaded, cell detail, and stagger diagram

Weight-stationary (what the TPU MXU uses)

The key idea: weights stay put, everything else flows.

For C = A @ B where A is (M, K) and B is (K, N):

  • The array is K rows x N columns (matching B’s dimensions)
  • Cell (k, n) holds B[k][n] — loaded once, never moves
  • Activations from A stream in from the left, one element per cell per cycle
  • Partial sums flow downward through each column
  • Result C[m][n] exits from the bottom of column n
         col 0     col 1
         +-----+   +-----+
A[m,0] > |B[0,0]| > |B[0,1]|   < row 0 (activation passes right)
         +--+--+   +--+--+
            | S       | S        < partial sums flow down
         +--+--+   +--+--+
A[m,1] > |B[1,0]| > |B[1,1]|   < row 1
         +--+--+   +--+--+
            | S       | S
         +--+--+   +--+--+
A[m,2] > |B[2,0]| > |B[2,1]|   < row 2
         +--+--+   +--+--+
            |          |
         C[m,0]     C[m,1]       < results exit bottom

Why weight-stationary? In neural network inference, the same weights multiply many different input batches. Loading weights once and streaming activations through means the most expensive data (weights — large, reused) never moves.

The stagger

Here’s the part I had to stare at. A[m][k] doesn’t enter row k at the same time as A[m][0] enters row 0. It’s staggered: A[m][k] enters row k delayed by k cycles. Why? Because partial sums flow downward — cell (k, n) needs to receive both:

  1. The activation A[m][k] from the left
  2. The partial sum from cell (k-1, n) above — which takes k cycles to get there (flowing down from row 0)

The stagger synchronizes these two data flows. Without it, the activation would arrive before its matching partial sum, or vice versa.

Here’s the timing for a (2, 3) @ (3, 2) matmul:

Cycle:    0          1          2          3
       +------+  +------+  +------+  +------+
Row 0: |A[0,0]|  |A[1,0]|  |      |  |      |
       +------+  +------+  +------+  +------+
Row 1: |      |  |A[0,1]|  |A[1,1]|  |      |    < delayed by 1
       +------+  +------+  +------+  +------+
Row 2: |      |  |      |  |A[0,2]|  |A[1,2]|    < delayed by 2
       +------+  +------+  +------+  +------+

Each new row of A (m=0, m=1) only costs 1 extra cycle. The pipeline is always full — no bubbles between different rows of A within one matmul. Total cycles: M + K + N - 2.
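A quick check (my own) of that cycle count against the stagger timing above:

```python
# With the stagger, A[m, k] enters row k at cycle m + k, and C[m, n]
# exits the bottom of column n at cycle m + K + n - 1. The last output
# C[M-1][N-1] therefore appears at cycle M + K + N - 3, so the matmul
# occupies M + K + N - 2 cycles counting from cycle 0.
def exit_cycle(m, n, K):
    return m + K + n - 1

M, K, N = 2, 3, 2                     # the (2, 3) @ (3, 2) example above
last = exit_cycle(M - 1, N - 1, K)    # cycle 4
assert last + 1 == M + K + N - 2      # 5 cycles total
```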

Output-stationary (not the TPU, but it shows up in diagrams)

Searching for systolic array diagrams will often turn up a different design where both A and B stream in — A from the left, B from the top. This is the output-stationary design:

  • The array is M rows x N columns (matching C’s dimensions)
  • Cell (i, j) accumulates C[i][j] — the result builds up in place
  • Both inputs flow through and keep moving

This is the design that shows “both matrices streaming from two sides.” It’s a valid design, but it’s not what the TPU uses. The TPU uses weight-stationary because it minimizes the most expensive data movement for inference workloads.


Building a systolic array emulator

To really understand the timing, I built a tick-based emulator: a SystolicArray class with a tick() method that advances one cycle, moving data through the pipeline exactly as the hardware would.

class SystolicArray:
    """Fixed-size weight-stationary systolic array emulator (TPU MXU design).

    Dimensions:
        - The array has `num_rows` rows and `num_cols` columns of cells.
        - B (num_rows x num_cols) is pre-loaded into cells — one weight per cell, stationary.
        - A (num_activations x num_rows) streams in from the left, one row of A per cycle,
          staggered: A[m, row] enters at cycle (m + row).
        - Partial sums flow downward through rows. Result C[m, col] exits
          the bottom of column `col` at cycle (m + num_rows + col - 1).
    """

    def __init__(self, num_rows: int, num_cols: int):
        self.num_rows = num_rows      # K: inner dimension of the matmul
        self.num_cols = num_cols      # N: number of output columns

        self.weights = np.zeros((num_rows, num_cols))
        # NaN means the cell is idle (no activation has arrived yet)
        self.activation_in_cell = np.full((num_rows, num_cols), np.nan)
        # Row 0 starts at 0; each row adds its contribution and passes down
        self.partial_sum = np.zeros((num_rows, num_cols))

        self.cycle = 0
        self._A = None
        self._num_activations = 0
        self._total_cycles = 0
        self._done = False
        self.results = {}             # (m, col) -> final dot product value

    def load_weights(self, B):
        """Pre-load weight matrix B into the array. One weight per cell, stays fixed."""
        assert B.shape == (self.num_rows, self.num_cols)
        self.weights = B.astype(np.float64).copy()

    def start_matmul(self, A):
        """Queue activation matrix A for streaming. Resets all pipeline state."""
        num_activations, inner_dim = A.shape
        assert inner_dim == self.num_rows
        self._A = A.astype(np.float64).copy()
        self._num_activations = num_activations
        self._total_cycles = num_activations + self.num_rows + self.num_cols - 2
        self._done = False
        self.cycle = 0
        self.results = {}
        self.activation_in_cell = np.full((self.num_rows, self.num_cols), np.nan)
        self.partial_sum = np.zeros((self.num_rows, self.num_cols))

    def tick(self):
        """Advance the array by one cycle."""
        current_cycle = self.cycle
        new_activation_in_cell = np.full((self.num_rows, self.num_cols), np.nan)
        new_partial_sum = np.zeros((self.num_rows, self.num_cols))

        for row in range(self.num_rows):
            for col in range(self.num_cols):

                # Step 1: Where does this cell's activation come from?
                if col == 0:
                    # First column: from the input queue.
                    # A[m, row] enters at cycle t = m + row (the stagger).
                    activation_index = current_cycle - row
                    if 0 <= activation_index < self._num_activations:
                        activation = float(self._A[activation_index, row])
                    else:
                        activation = None     # ramp-up or drain phase
                else:
                    # Other columns: passes rightward from the left neighbor.
                    left_neighbor = self.activation_in_cell[row, col - 1]
                    if np.isnan(left_neighbor):
                        activation = None     # left neighbor was idle
                    else:
                        activation = float(left_neighbor)

                # Step 2: Partial sum from above
                if row == 0:
                    incoming_partial_sum = 0.0   # top row starts at zero
                else:
                    incoming_partial_sum = float(self.partial_sum[row - 1, col])

                # Step 3: Compute if we have an activation
                if activation is not None:
                    weight = float(self.weights[row, col])
                    updated_partial_sum = incoming_partial_sum + activation * weight

                    new_activation_in_cell[row, col] = activation
                    new_partial_sum[row, col] = updated_partial_sum

                    # Bottom row: result exits the array
                    if row == self.num_rows - 1:
                        result_index = current_cycle - row - col
                        if 0 <= result_index < self._num_activations:
                            self.results[(result_index, col)] = updated_partial_sum
                else:
                    new_partial_sum[row, col] = incoming_partial_sum

        self.activation_in_cell = new_activation_in_cell
        self.partial_sum = new_partial_sum
        self.cycle += 1
        if self.cycle > self._total_cycles:
            self._done = True

    @property
    def done(self):
        return self._done

    def matmul(self, A, B):
        """Load weights, stream A, tick until done, return result matrix."""
        self.load_weights(B)
        self.start_matmul(A)
        while not self.done:
            self.tick()
        C = np.zeros((self._num_activations, self.num_cols))
        for (m, col), value in self.results.items():
            C[m, col] = value
        return C

Quick test:

A @ B = [[ 4.  5.]
 [10. 11.]]
Emulator = [[ 4.  5.]
 [10. 11.]]
Match: True
Total cycles: 6  (M+K+N-2+1 = 6)

Cycle-by-cycle systolic array execution — (2,3) @ (3,2) over 6 cycles
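The driver that produced this output isn't shown in the post. Here is a reconstruction; the matrices A and B are my guess, chosen because they reproduce the printed numbers:

```python
import numpy as np

# Hypothetical test inputs for a (2,3) @ (3,2) matmul; any pair works,
# but these reproduce the printed [[4, 5], [10, 11]].
A = np.array([[1., 2., 3.],
              [4., 5., 6.]])
B = np.array([[1., 0.],
              [0., 1.],
              [1., 1.]])

C = A @ B
# [[ 4.  5.]
#  [10. 11.]]

M, K = A.shape
N = B.shape[1]
assert (M + K + N - 2) + 1 == 6  # the emulator ticks one cycle past the formula
```

Feeding the same A and B through SystolicArray(num_rows=3, num_cols=2).matmul(A, B) should return the identical matrix.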

The key insight from building this: the stagger isn’t a complication, it’s the mechanism. By delaying A[m, k]’s entry into row k by exactly k cycles, the activation arrives at each cell at the same moment as the matching partial sum from above. The pipeline stays full, no control logic needed. It’s elegant.

I wired the emulator into a TPUCycleSimulator that counts MXU and VPU cycles for the full attention computation — both flash and standard. For small matrices (all dimensions ≤ 16), it ticks through the actual systolic array and verifies the cycle count matches the M + K + N - 2 formula. For larger matrices, it uses the formula directly.

class TPUCycleSimulator:
    """Approximate cycle-level simulation of TPU MXU + VPU.
    Uses the SystolicArray emulator for matmuls — the cycle count
    falls out of the hardware simulation rather than a formula.
    """

    def __init__(self, mxu_dim=128, vpu_width=128):
        self.mxu_dim = mxu_dim
        self.vpu_width = vpu_width
        self.mxu_cycles = 0
        self.vpu_cycles = 0
        self.mxu_flops = 0

    def matmul(self, A, B):
        """Route through the systolic array emulator.
        For tiles that fit (K,N <= 16), tick through actual hardware pipeline.
        The cycle count M+K+N-2 isn't assumed — it's verified.
        """
        M, K = A.shape
        _, N = B.shape
        formula_cycles = M + K + N - 2

        if K <= 16 and N <= 16 and M <= 16:
            arr = SystolicArray(num_rows=K, num_cols=N)
            C = arr.matmul(A, B)
            assert arr.cycle == formula_cycles + 1
        else:
            C = A @ B

        self.mxu_cycles += formula_cycles
        self.mxu_flops += 2 * M * K * N
        return C

    def vpu(self, n_elements, cycles_per_vec=1):
        """Elementwise VPU op. 128 elements per vector.
        Ceiling division: (n-1)//128+1 so exact multiples don't overshoot."""
        self.vpu_cycles += ((n_elements - 1) // self.vpu_width + 1) * cycles_per_vec

Systolic array cycle counts verified against formula ✓

What the emulator revealed

The simulator compares flash attention (block=128) against standard attention for n=512, d=64:

                                    block=64   block=128    standard
  ─────────────────────────────────────────────────────────────────
  Total cycles                        24,556      16,936      20,604
  MXU cycles                          13,680       6,360       2,172
  VPU cycles                          10,876      10,576      18,432
  MXU utilization                       8.4%       20.1%       94.3%
  vs standard                          1.19x       0.82x       1.00x

Flash does less total compute for causal attention. It skips entire tiles in the upper triangle — 6 tiles out of 16 for a 4×4 grid. Standard attention processes the full n×n matrix, running exp(-inf) on all the masked entries. Flash never touches them at all.

But MXU utilization tells the real story. Even with block=128, flash attention’s MXU utilization is only ~20% vs standard’s ~94%. Flash has two matmuls per tile: Q_tile @ K_tile.T = (128, 64) @ (64, 128) and weights @ V_tile = (128, 128) @ (128, 64). Both have inner dimension ≤ d=64 or block=128, so the systolic pipeline runs for at most 128 steps through a 128-wide array. Standard attention’s weights @ V is (512, 512) @ (512, 64) — the inner dimension is 512, giving the pipeline 512 steps of useful work. That single large matmul is what drives standard’s ~94% utilization.
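Those utilization numbers fall straight out of the M + K + N - 2 cost model: count multiply-accumulates per cycle against the 128x128 array's peak. (A simplification, since a real MXU tiles any matmul wider than 128 into multiple passes, but it is the model the simulator uses.)

```python
def mxu_utilization(M, K, N, mxu_dim=128):
    """MACs per cycle vs the 128x128 array's peak, under the M+K+N-2 model."""
    cycles = M + K + N - 2
    macs = M * K * N
    return macs / (cycles * mxu_dim * mxu_dim)

# flash, block=128: both per-tile matmuls land at ~20%
flash_qk = mxu_utilization(128, 64, 128)   # Q_tile @ K_tile.T, ~0.201
flash_wv = mxu_utilization(128, 128, 64)   # weights @ V_tile,  ~0.201

# standard: one big weights @ V with inner dimension 512
std_wv = mxu_utilization(512, 512, 64)     # ~0.943
```

The back-of-envelope numbers match the table's 20.1% and 94.3% almost exactly: the inner dimension determines how long the pipeline has useful work relative to its fill and drain overhead.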

The simulator likely overcounts standard attention though. A fused XLA kernel could, in principle, recognize the causal mask and skip the upper triangle entirely — never compute exp(-inf), never multiply by zero weights. The simulator charges full price for the masked entries; a smart compiler probably wouldn’t. (Without profiling the actual XLA-generated code, this is speculation — but the benchmark gap is consistent with it.)

The sharpest version of the insight: The algorithm does less compute than standard attention. vmap proves it — once XLA can see the Q-block parallelism, it gets within 2x of the fused path and beats it at large sizes. The remaining gap is likely DMA pipelining and fusion — things only a lower-level API can express. (Dumping the HLO would confirm this; for now it’s an educated guess from the benchmark shape.)

What production code does

jax.nn.dot_product_attention is JAX’s built-in attention. XLA recognizes the pattern and applies its own optimized implementation:

@jax.jit
def builtin_causal_attention(Q, K, V):
    # Expects (batch..., seq, heads, head_dim) — NOT (seq, d).
    # Add heads=1 dimension: (n, d) -> (n, 1, d) -> call -> squeeze back.
    out = jax.nn.dot_product_attention(
        Q[:, None, :], K[:, None, :], V[:, None, :],
        is_causal=True,
    )
    return out[:, 0, :]
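For reference, the math the builtin performs. This is a plain numpy sketch of single-head causal attention, not JAX's implementation, just the semantics:

```python
import numpy as np

def causal_attention_ref(Q, K, V):
    """Reference single-head causal attention on (n, d) arrays."""
    n, d = Q.shape
    scores = (Q @ K.T) / np.sqrt(d)
    # mask the strict upper triangle: position i can't attend to j > i
    scores[np.triu_indices(n, k=1)] = -np.inf
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)                        # exp(-inf) -> 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
out = causal_attention_ref(Q, K, V)
# row 0 can only attend to itself, so it must equal V[0]
assert np.allclose(out[0], V[0])
```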

The benchmark bears this out, with performance essentially matching fused standard attention at every size:

     n   scores(MB)    VMEM?    standard(ms)   flash(ms)   builtin(ms)   builtin speedup
-------------------------------------------------------------------------------------
   512        1.0      yes         0.070       0.070         0.067            1.05x
  1024        4.0      yes         0.066       0.133         0.079            0.85x
  2048       16.0      yes         0.073       0.521         0.081            0.91x
  4096       64.0      yes         0.073       2.507         0.074            0.99x
  8192      256.0       NO         1.188      14.051         1.189            1.00x
 16384     1024.0       NO         4.444      89.542         4.448            1.00x
 32768     4096.0       NO        17.115     102.995        17.222            0.99x

For anything beyond what XLA auto-selects, there’s Splash Attention — Google’s TPU-optimized flash attention written in Pallas. It uses DMA pipelining, MXU-matched tile sizes, and 2D grid scheduling — everything my fori_loop couldn’t express.


Pallas: what it would take to beat the compiler

So how does Splash Attention actually beat XLA’s fused path? Pallas — JAX’s equivalent of Triton. Write custom kernels in Python that lower through Mosaic to TPU VLIW instructions.

The three things Pallas provides that pure JAX can’t express:

  1. DMA pipelining. The fori_loop implementation likely does load-wait-compute-load-wait-compute. A Pallas kernel can double-buffer: while the MXU computes on the current tile, the DMA engine fetches the next tile into a separate VMEM buffer. Compute and memory transfer overlap instead of serializing.

  2. MXU-matched tiling with causal skipping. A 2D Pallas grid (num_q_blocks, num_kv_blocks) gives Mosaic full visibility into the iteration pattern. It knows which tiles are fully masked by the causal triangle and skips them entirely — no wasted MXU cycles.

  3. Explicit VMEM placement. All data movement goes through BlockSpec declarations — no dynamic indexing in the kernel body. This is how the hardware knows what to prefetch.
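A toy cost model makes the double-buffering payoff concrete (hypothetical cycle counts, nothing measured on hardware):

```python
def serial_cycles(n_tiles, load, compute):
    # load-wait-compute, load-wait-compute: nothing overlaps
    return n_tiles * (load + compute)

def pipelined_cycles(n_tiles, load, compute):
    # double-buffered: only the first load is exposed; afterwards the DMA
    # engine fetches the next tile while the MXU works on the current one
    return load + (n_tiles - 1) * max(load, compute) + compute

# e.g. 16 tiles with load and compute roughly balanced
serial = serial_cycles(16, load=100, compute=120)        # 3520
pipelined = pipelined_cycles(16, load=100, compute=120)  # 100 + 15*120 + 120 = 2020
```

When compute dominates, the load cost all but disappears behind it; when load dominates, the kernel is memory-bound no matter how the MXU is scheduled.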

I tried writing one. Mosaic’s constraints are restrictive — no dynamic indexing (k_all[indices, :] lowers to an unsupported gather), 1D blocks must be multiples of 128, kernels that compile on one JAX version fail on another. The code didn’t survive into this post. There’s a reason Splash Attention is a serious engineering effort, not a code snippet.

At this point my brain was pretty thoroughly consumed by the TPU architecture rabbit hole. The Pallas deep dive can wait for another day.

  Approach                           When to use
  ---------------------------------  --------------------------------------------------------
  jax.nn.dot_product_attention       Default. XLA picks the best strategy.
  Splash Attention (Pallas)          Long sequences at scale, kernel-level tuning beyond XLA.
  Pure JAX fori_loop (what I wrote)  Understanding the algorithm. Not for production.

What I actually learned

The hardware was already doing it

The whole arc of this post is: I tried to force a GPU optimization onto a TPU, and — for this setup (single head, d=64, Colab v5e) — the TPU was already handling it natively.

Flash attention exists because GPU SRAM is tiny (~164 KB/SM) — the n×n score matrix never fits, so tiling in software is mandatory. On TPU, the MXU is literally a tile processor. A 128x128 systolic array that holds one matrix stationary and streams the other through — that’s what flash attention implements in software on GPU, but it’s what the TPU hardware does by default.

Add ~128 MB of VMEM (800x more on-chip memory than a GPU SM), and XLA’s automatic fusion, and the score matrix just… stays on-chip. My handwritten tiling was reimplementing what the hardware and compiler already handle, but worse. (At production scale — multi-head, longer sequences, larger d — the tradeoffs shift and Splash Attention becomes necessary. But for the single-head setup I was benchmarking, the compiler had it covered.)
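The fits-in-VMEM crossover in the benchmark table is just arithmetic. A quick check with fp32 scores against the ~128 MB figure (approximate; exact VMEM capacity varies by TPU generation):

```python
VMEM_MB = 128  # approximate on-chip VMEM per chip, per the post

def scores_mb(n):
    return n * n * 4 / 2**20  # fp32 n x n score matrix

assert scores_mb(512) == 1.0
assert scores_mb(4096) == 64.0 and scores_mb(4096) <= VMEM_MB   # fits
assert scores_mb(8192) == 256.0 and scores_mb(8192) > VMEM_MB   # first size that spills
```

Which lines up with the benchmark: standard attention's time jumps about 16x between n=4096 and n=8192, exactly where the score matrix stops fitting on-chip.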

Giving the compiler information matters more than writing clever code

The 45x speedup from fori_loop to vmap wasn’t a better algorithm. It was the same algorithm with one additional piece of information: “these Q blocks are independent.” XLA is a JIT compiler — it does dataflow analysis, operator fusion, memory planning. But it can’t infer independence from a fori_loop with carried state. vmap is semantically “map this function over a batch” — independence is built into the abstraction.

This is a different skill than writing Triton kernels. In Triton, the programmer is the compiler — deciding what goes where. In JAX, it’s a conversation with a compiler. The better the intent is expressed, the better code it generates. fori_loop said “do these sequentially.” vmap said “these are independent.” Same math. 45x difference.

Tiling is the same idea everywhere — it’s just a question of who does it

                             TPU                               GPU
  -------------------------  --------------------------------  --------------------------------------------------------
  Tile-level matmul          hardware (MXU is a 128x128 tile)  software (tensor cores need warp-level MMA instructions)
  Tiling schedule            compiler (XLA)                    programmer (Triton/CUDA) or compiler (torch.compile)
  On-chip memory management  compiler (VMEM)                   programmer (shared memory)

Same building block: tile, stream, accumulate. TPU pushes more into hardware and compiler. GPU gives more control but requires more work. The end result is the same math at the same scale.

The comparison table

                            Triton / GPU (Part 4)                       JAX / TPU (this post)
  ------------------------  ------------------------------------------  ------------------------------------------------
  Compiler                  Triton -> LLVM -> PTX                       JAX -> HLO -> XLA -> device code
  Fusion                    I fuse manually (the kernel IS the fusion)  XLA fuses automatically
  Tiling                    Manual pointer arithmetic                   Implicit (compiler decides) or BlockSpec (Pallas)
  Memory control            I decide what lives in SRAM                 Compiler decides what lives in VMEM
  On-chip SRAM              ~164 KB / SM                                ~128 MB / chip
  Why flash attention wins  SRAM is tiny -> score matrix NEVER fits     VMEM is huge -> score matrix fits until ~n=8K (single head, d=64)

The biggest lesson: the same optimization has completely different value on different hardware. I spent Parts 3-4 building up flash attention as this essential technique — and it is, on GPU. On TPU — at least for this single-head, d=64 setup on a Colab v5e — the hardware architecture makes it unnecessary for typical sequence lengths, and the compiler handles it when it does become necessary. Understanding why I lost taught me more about both architectures than winning on GPU did.
