Writing · May 2026 · ~8 min read

Most attention is wasted.

Take a transformer attending over 2048 tokens. It computes about 4.2 million pairwise interactions (2048²) per layer per head. Run a softmax over each row, look at the resulting weights, and the picture is the same almost every time: a sharp spike on a few neighbouring positions, a thin tail across a few faraway ones, and zero or near-zero on everything else.

The model paid quadratic compute for that row. The information density is closer to linear.

This is the trade-off that an entire research line has been chasing — Longformer, BigBird, Performer, Linformer, Perceiver, Compressive Transformer, Memorizing Transformer, Mamba. They all replace the dense N×N pattern with something cheaper but expressive enough to keep the model strong. I wanted a small, working reference implementation of one specific recipe: a sliding window for local detail plus a learned latent memory bank for long-range routing. So I wrote one.

It is called MSA-lite. About 200 lines of attention code, 18 tests, a measured 10x speedup over nn.MultiheadAttention at 8192 tokens on CPU. The interesting part is not that the speedup exists — the math says it has to. The interesting part is that you can poke at the resulting layer in real time and see the architecture make sense.

Try it

The whole layer is wrapped in a Streamlit dashboard, hosted live below. Move the sliders, retrain on the selective-copy task, watch the attention pattern evolve. No setup; the demo runs server-side on Hugging Face Spaces.

MSA-lite live demo (HF Spaces, first load takes ~30s): abhishekshekhar-msa-lite-demo.hf.space

The trade, in mechanical terms

Standard attention asks every token to look at every other token. The cost is O(N²) and most of it is wasted, as established above. MSA-lite replaces that single dense path with two cheap paths in parallel.

The local path chunks the sequence into blocks of width W. Each block attends to itself plus one block to the left and one to the right, so every token sees at least W positions on either side (and just under 2W at most, depending on where it sits in its block). The key trick is that the score matrix is never materialised at full size: the largest tensor in the local path is W×3W per chunk. Cost: O(N·W).

The memory path introduces a small bank of K learned vectors. Every input token cross-attends into this bank to read whatever long-range information the model has decided is worth caching. K is fixed and small (16 to 64 in practice), so this path costs O(N·K). The slots are real parameters, trained end-to-end, and they get refreshed each forward step from a compressive read of the input. Practically: they are the model’s working scratchpad.
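A minimal sketch of the memory read, written in NumPy rather than the repo's PyTorch to keep it dependency-light. The function name and the three projection matrices `w_q`, `w_k`, `w_v` are assumptions for illustration; the post does not say how the real layer projects tokens or slots. The point is only that the score matrix is N×K, never N×N.

```python
import numpy as np

def memory_read(x, slots, w_q, w_k, w_v):
    """Every token cross-attends into K memory slots (single head, no batch).

    x: (n, d) tokens; slots: (k, d) memory bank; w_q, w_k, w_v: (d, d).
    Returns (output of shape (n, d), attention weights of shape (n, k)).
    """
    q = x @ w_q              # queries come from the tokens
    keys = slots @ w_k       # keys and values come from the slots
    vals = slots @ w_v
    scores = q @ keys.T / np.sqrt(x.shape[1])       # (n, k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ vals, weights

rng = np.random.default_rng(0)
n, k, d = 2048, 32, 16
x = rng.standard_normal((n, d))
slots = rng.standard_normal((k, d))
proj = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3)]
out, weights = memory_read(x, slots, *proj)
print(out.shape, weights.shape)  # (2048, 16) (2048, 32)
```

Doubling N doubles the rows of `weights` and nothing else, which is where the O(N·K) cost comes from.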

A learned gate, a sigmoid over a per-token linear projection, decides how much each token routes through local versus memory. Untrained, it sits near 0.5 because both paths are random. Trained on a long-range task, it specialises — certain positions clearly want the memory bank, others clearly want their neighbours.
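The gate can be sketched in a few lines of NumPy. The names `gated_mix`, `w_gate`, and `b_gate` are hypothetical, and the convention that a gate of 1 means "memory only" is an assumption, chosen to match the clamped-gate bug described later in the post.

```python
import numpy as np

def gated_mix(x, local_out, memory_out, w_gate, b_gate=0.0):
    """Blend the two paths with a per-token sigmoid gate.

    x, local_out, memory_out: (n, d); w_gate: (d,); b_gate: scalar.
    Assumed convention: gate near 1 favours memory, near 0 favours local.
    """
    gate = 1.0 / (1.0 + np.exp(-(x @ w_gate + b_gate)))   # (n,)
    mixed = gate[:, None] * memory_out + (1.0 - gate[:, None]) * local_out
    return mixed, gate

rng = np.random.default_rng(0)
n, d = 6, 4
x = rng.standard_normal((n, d))
local_out = rng.standard_normal((n, d))
memory_out = rng.standard_normal((n, d))

# With a zero projection the gate is exactly 0.5 for every token,
# so the output is the plain average of the two paths.
mixed, gate = gated_mix(x, local_out, memory_out, w_gate=np.zeros(d))
print(gate)
```

A small random projection gives values scattered close to 0.5 rather than exactly on it, which is the untrained behaviour described above.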

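Total per-layer cost: O(N·(W + K)). Linear in N. At sequence length 8192 with W=64 and K=32, N·(W + K) is 786,432 versus N² ≈ 67 million for full attention (counting query-key scores per head and ignoring constant factors such as the 3W-wide key window). In wall-clock terms on CPU, 34 ms versus 336 ms.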
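The counts are easy to check by hand. The helper below is illustrative only (its name and the "three-block" variant are not from the repo); it counts query-key score entries per head and ignores projections and everything else.

```python
def score_counts(n: int, w: int, k: int) -> dict:
    """Query-key score entries per head for each attention variant."""
    return {
        "full": n * n,                        # dense N x N
        "msa_nominal": n * (w + k),           # the O(N*(W+K)) figure
        "msa_three_block": n * (3 * w + k),   # counting all 3W keys per query
    }

counts = score_counts(n=8192, w=64, k=32)
print(counts["full"])             # 67108864
print(counts["msa_nominal"])      # 786432
print(counts["msa_three_block"])  # 1835008
print(round(counts["full"] / counts["msa_nominal"], 1))  # 85.3
```

The ratio of score counts (about 85x) is much larger than the measured 34 ms versus 336 ms, presumably because projections, reshapes, and other fixed costs do not shrink with the score matrix.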

Why this combination, specifically

Sliding-window-only attention (Longformer’s recipe minus the global tokens) is good at local syntax and bad at long-range routing. Memory-only attention (Perceiver-style) is good at routing and lossy at local detail because every token has to round-trip through K slots. Combining them is the obvious move and the prior work bears it out, but I wanted to feel the trade-off myself.

The thing that becomes elegant once you write it is that the memory bank is shared across the whole batch and refreshed inline each forward pass. The refresh is a small EMA over a non-parametric attention read, which keeps autograd clean and prevents the slots from being clobbered batch by batch. So at every training step the slots get nudged toward whatever pattern is currently useful, the same way a working scratchpad gets updated as you read.
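One way to read that refresh, as a NumPy sketch under stated assumptions: the slots act as queries over all tokens in the batch with no projection matrices (the "non-parametric" part), and the result is blended in with a smoothing rate `alpha`. The function name, the value of `alpha`, and the choice to pool the batch by flattening it are all guesses; the post does not spell them out.

```python
import numpy as np

def refresh_slots(slots, x, alpha=0.1):
    """EMA update of memory slots from a projection-free attention read.

    slots: (k, d) memory bank; x: (b, n, d) batch of token embeddings.
    Returns the updated (k, d) bank; inputs are left unmodified.
    """
    b, n, d = x.shape
    tokens = x.reshape(b * n, d)              # one bank shared by the whole batch
    scores = slots @ tokens.T / np.sqrt(d)    # (k, b*n): slots query the tokens
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    read = weights @ tokens                   # (k, d) compressive summary
    return (1.0 - alpha) * slots + alpha * read

rng = np.random.default_rng(0)
slots = rng.standard_normal((4, 8))
x = rng.standard_normal((2, 16, 8))
new_slots = refresh_slots(slots, x, alpha=0.1)
print(np.abs(new_slots - slots).max())  # a small nudge, not a full overwrite
```

With a small `alpha` each batch moves the slots only slightly, which matches the description of slots not being clobbered batch by batch. In PyTorch an update like this would typically be kept out of the autograd graph, for example with `torch.no_grad()`.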

The chunked local attention, mechanically

This is the part that took the most care to get right.

The naive implementation builds the N×N score matrix, masks it to a band of width 2W+1, softmaxes, multiplies by V. That works but defeats the purpose — you have already paid the O(N²) memory cost.

The real implementation reshapes Q, K, V so the N axis becomes (chunks × W). For each chunk, the keys are built as [prev_chunk | current | next_chunk], a 3W-wide tensor. Then you do a small W×3W attention per chunk in parallel, and never materialise anything bigger than 3W·W. The boundary chunks at sequence start and end need a mask to zero out the keys that don’t exist (no chunk before chunk 0, no chunk after the last one).
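The reshape-and-concatenate idea can be sketched in NumPy for a single head with no batch dimension. This is not the repo's code: the function names are made up, N is assumed to be divisible by W, and the dense reference exists only to check that the chunked version computes the same thing.

```python
import numpy as np

def chunked_local_attention(q, k, v, w):
    """Each chunk of width w attends to [prev | current | next].

    q, k, v: (n, d) with n divisible by w.
    Returns (output (n, d), weights (n // w, w, 3 * w)).
    """
    n, d = q.shape
    c = n // w
    qc, kc, vc = (t.reshape(c, w, d) for t in (q, k, v))

    pad = np.zeros((1, w, d), dtype=q.dtype)   # stands in for missing neighbours
    k3 = np.concatenate([np.concatenate([pad, kc[:-1]]), kc,
                         np.concatenate([kc[1:], pad])], axis=1)   # (c, 3w, d)
    v3 = np.concatenate([np.concatenate([pad, vc[:-1]]), vc,
                         np.concatenate([vc[1:], pad])], axis=1)

    scores = qc @ k3.transpose(0, 2, 1) / np.sqrt(d)   # (c, w, 3w)

    valid = np.ones((c, 1, 3 * w), dtype=bool)
    valid[0, :, :w] = False        # no chunk before chunk 0
    valid[-1, :, 2 * w:] = False   # no chunk after the last one
    scores = np.where(valid, scores, -np.inf)   # mask BEFORE softmax

    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return (weights @ v3).reshape(n, d), weights

def dense_reference(q, k, v, w):
    """O(n^2) version with the same block-band mask, for checking only."""
    n, d = q.shape
    chunk = np.arange(n) // w
    allowed = np.abs(chunk[:, None] - chunk[None, :]) <= 1
    scores = np.where(allowed, q @ k.T / np.sqrt(d), -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    return (p / p.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
n, d, w = 32, 8, 4
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
out, weights = chunked_local_attention(q, k, v, w)
print(weights.shape)                                  # (8, 4, 12)
print(np.allclose(out, dense_reference(q, k, v, w)))  # True
```

The largest intermediate in the chunked path has shape (N/W, W, 3W), so its size grows linearly with N, while the reference builds the full N×N matrix.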

It took three attempts to get the mask right. The first version had a subtle bug where chunks at the start of the sequence attended to padding, and softmax gave that padding nonzero weight. The fix was to mask before softmax with -inf, not after with zero. Standard footgun. The test suite catches it now.
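The difference between the two orderings shows up on a single row. A small standard-library illustration with made-up scores, where the last two positions are padding:

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]

scores = [2.0, 1.0, 0.0, 0.0]          # last two positions are padding
is_real = [True, True, False, False]

# Masking AFTER softmax: padding has already taken its share of the
# probability mass, so the surviving weights no longer sum to 1.
after = [p if ok else 0.0 for p, ok in zip(softmax(scores), is_real)]

# Masking BEFORE softmax: -inf scores become exactly zero weight and the
# real positions are renormalised among themselves.
before = softmax([s if ok else -math.inf for s, ok in zip(scores, is_real)])

print(sum(after))    # noticeably less than 1
print(sum(before))   # 1.0 up to floating-point rounding
```

Masking before softmax is only safe when every row keeps at least one real key; a row that is entirely -inf would produce NaNs. In the three-block layout the current chunk is always present, so that case does not arise.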

The dashboard, and why it exists

Attention is hard to debug because it lives inside (B, H, N, N) tensors that no one looks at. You write the code, the tests pass, the benchmark shows a speedup, and you have basically no insight into whether the model is actually behaving the way you think.

So I added an inspection path. Calling attn(x, return_weights=True) returns the layer output plus an MSAWeights namedtuple with three things: the reconstructed banded local-attention matrix, the N×K memory cross-attention matrix, and the per-token gate values. The reconstruction is opt-in — only paid for when requested — so the default forward path is unchanged.
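Reconstructing the banded matrix amounts to scattering the per-chunk weights back into an N×N grid. A NumPy sketch under assumptions: the chunk weights are laid out as (chunks, W, 3W) with keys ordered [prev | current | next], and the function name is hypothetical rather than the repo's.

```python
import numpy as np

def reconstruct_banded(weights, w):
    """Scatter per-chunk weights (c, w, 3w) into a dense (n, n) matrix."""
    c = weights.shape[0]
    n = c * w
    full = np.zeros((n, n), dtype=weights.dtype)
    for i in range(c):
        rows = slice(i * w, (i + 1) * w)
        lo = max(i - 1, 0) * w          # first real key column for this chunk
        hi = min(i + 2, c) * w          # one past the last real key column
        # Column 0 of the 3w axis corresponds to absolute key (i - 1) * w.
        start = lo - (i - 1) * w
        full[rows, lo:hi] = weights[i, :, start:start + (hi - lo)]
    return full

# Uniform weights over the valid keys of each chunk, for a visual check.
w, c = 2, 3
chunk_weights = np.ones((c, w, 3 * w))
chunk_weights[0, :, :w] = 0.0        # no chunk before the first
chunk_weights[-1, :, 2 * w:] = 0.0   # no chunk after the last
chunk_weights /= chunk_weights.sum(axis=-1, keepdims=True)

full = reconstruct_banded(chunk_weights, w)
print((full > 0).astype(int))
```

The loop allocates the full N×N matrix, which is exactly the cost the forward path avoids, so it makes sense for this to run only when weights are requested.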

Then I wrapped it in the Streamlit dashboard you can drive at the top of this page. Four tabs:

Attention patterns. Run a forward pass on random input and show both attention heatmaps side by side. The local-attention map should look like a bright band along the diagonal, up to 3W wide per row because of the three-block key window, and dark elsewhere. That is the proof the sparsity is real, not a claim. The memory-attention map shows which slots each token reads from. Untrained, it’s near-uniform noise. Trained, certain slots take over for certain regions.

Live training. Train a tiny MSA Transformer on the selective-copy task — long random sequence, a few mark tokens, model has to copy whatever value follows each mark. It is the standard probe for long-range information routing in this literature. You watch the loss drop and accuracy climb in real time, then flip back and see how the attention patterns and the gate have specialised.
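For readers who have not met the task, here is one possible data generator using only the standard library. The token ids, the answer-slot layout at the end of the sequence, and both function names are assumptions; the post does not specify the exact format used in the repo.

```python
import random

BLANK, MARK = 0, 1   # assumed ids; values use ids 2..vocab_size-1

def selective_copy_example(seq_len, n_marks, vocab_size, rng):
    """Return (inputs, targets) for one example.

    The last n_marks input positions are blank answer slots; the targets
    there hold the value that followed each mark, in order of appearance.
    """
    body = seq_len - n_marks
    inputs = [rng.randrange(2, vocab_size) for _ in range(body)] + [BLANK] * n_marks
    # Marks sit on even positions, so the token after a mark is never a mark.
    positions = sorted(rng.sample(range(0, body - 1, 2), n_marks))
    for p in positions:
        inputs[p] = MARK
    answers = [inputs[p + 1] for p in positions]
    targets = [BLANK] * body + answers
    return inputs, targets

def non_blank_accuracy(predictions, targets):
    """Accuracy over positions whose target is not BLANK."""
    pairs = [(p, t) for p, t in zip(predictions, targets) if t != BLANK]
    return sum(p == t for p, t in pairs) / len(pairs)

rng = random.Random(0)
inputs, targets = selective_copy_example(seq_len=512, n_marks=8, vocab_size=16, rng=rng)
print(inputs.count(MARK), targets[-8:])
```

Scoring only the non-blank positions keeps the metric from being inflated by a model that predicts blank everywhere.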

Speed benchmark. Sweep sequence lengths from 256 to 4096 and plot MSA against vanilla nn.MultiheadAttention. The two curves cross around 512 tokens. Below that, full attention’s tighter inner loop wins. Above, MSA pulls away.

Memory bank. The K×d_model parameter heatmap, plus per-slot L2 norms. Flat norms mean every slot is being used roughly equally; uneven norms mean the model has decided some are dead weight. Useful diagnostic.
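The per-slot norm check is a one-liner in NumPy. The `dead_ratio` threshold below is an arbitrary illustrative cutoff, not something taken from the dashboard.

```python
import numpy as np

def slot_norm_report(slots, dead_ratio=0.1):
    """Return per-slot L2 norms and indices of slots far below the largest norm.

    slots: (k, d_model) memory bank. dead_ratio is a hypothetical cutoff.
    """
    norms = np.linalg.norm(slots, axis=1)
    dead = np.flatnonzero(norms < dead_ratio * norms.max())
    return norms, dead

slots = np.ones((4, 8))
slots[2] *= 1e-3                    # simulate a slot that has collapsed
norms, dead = slot_norm_report(slots)
print(norms.round(3), dead)
```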

The dashboard turned out to be the most useful part of the project. Twice now it has caught a bug that the test suite missed — once where a stale gate was clamped to 1.0 across the whole sequence (output looked correct, but the model had silently stopped using the local path), once where a chunk-boundary mask off-by-one made the leftmost token attend to padding.

Honest numbers

Measured on CPU, PyTorch 2.2, one layer, d_model=128, 4 heads, W=64, K=32:

At short sequences the chunking and memory-attention overhead dominates and full attention wins. Crossover near 512. The asymptotic argument starts being visible around 2048 and is dramatic by 4096. None of this is surprising, but it is satisfying to see it land exactly where the back-of-envelope math said it would.

On the selective-copy task at sequence length 512 with 8 marks, an MSA Transformer with 3 layers and 32 memory slots reaches 96% non-blank accuracy in 300 training steps on CPU. That is not a paper result, but it is the thing the architecture has to do for the rest of the claim to mean anything.

What I’d do next

Causal masking. The current implementation is bidirectional. For language modeling I would need to make the local window strictly leftward and confirm the memory cross-attention doesn’t leak future-token information through the slot refresh.

Compare directly to Mamba. State-space models claim the same linear scaling with arguably more elegance. I want to run both on the same long-context probe and see what the actual difference looks like.

GPU benchmarks. Everything here is CPU. The memory layout I chose for the chunked local attention should be friendly to GPU tiling, but I want to verify rather than claim.

Pretrain something tiny. A 6-layer 128-dim model on a small corpus, a few hundred million tokens. The selective-copy task only proves the architecture can route information; pretraining proves it can model language.

The bigger argument

The dense self-attention layer is a beautiful object: it is the thing that broke sequence modelling out of the RNN's sequential rut and made transformers possible. But it is also paying for compute it does not use. Most of the entries in a softmaxed attention row are essentially zero. The model has been told it can look anywhere, and at training time it figures out it shouldn’t bother with most places.

Local plus latent memory is one way of saying: only do the compute the model actually uses. The local path covers the fact that most useful interactions are short-range. The memory path covers the long-range ones in a budget-bounded way. The gate decides which path matters per token. Linear cost, no asymptotic catastrophe, end-to-end trainable.

The repo is on GitHub at abhi183/MSA-lite. About 700 lines of Python including tests and the dashboard. To run the dashboard yourself: pip install -e ".[viz]", then python -m msa.viz. Or just scroll up and use the embedded one.