How many FLOPs is attention? When does a KV cache hurt performance? Why does beam search cause your KV cache memory to explode?
With Claude’s help, I wrote ~30 interview questions on LLM Inference, focusing on the systems side of things: FLOPs, attention, and the KV cache.
This material is not often taught in universities because it is somewhat cutting edge. But you need to know this stuff to land internships at good AI labs and also, more importantly, to really understand how your GPU works and build fast AI systems.
My notes on these topics can be found at the bottom of this page.
How many FLOPs is A @ b where A.shape = (m, k) and b.shape = (k, 1)? What about A @ B where A.shape = (m, k) and B.shape = (k, n)? Where does the 2 come from?

A @ b takes 2 x m x k FLOPs: each entry of the result is the dot product of a row of A with b, which requires k multiplications and k-1 additions, i.e. k + (k-1) FLOPs. There are m entries in the resultant vector, so we have a total of m x (k + (k-1)) ≈ 2 x m x k FLOPs.

A @ B takes 2 x m x n x k FLOPs: each entry is again a dot product costing k + (k-1) FLOPs, and there are m x n entries in the resultant matrix, so we have a total of m x n x (k + (k-1)) ≈ 2 x m x n x k FLOPs.

The 2 comes from the fact that we must perform (roughly) one multiplication and one addition for each of the k terms contributing to every output element.
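To make the counting concrete, here's a minimal Python sketch (my own, with arbitrary dimensions) that tallies the multiplications and additions of a naive matvec/matmul and compares them to the 2 x m x k and 2 x m x n x k approximations:

```python
# Counting FLOPs of a naive matvec / matmul and comparing to the 2*m*k
# and 2*m*n*k approximations (dimensions here are arbitrary).
def matvec_flops(m, k):
    # m output entries, each a dot product: k multiplies + (k - 1) adds
    return m * (k + (k - 1))

def matmat_flops(m, k, n):
    # m*n output entries, each a length-k dot product
    return m * n * (k + (k - 1))

m, k, n = 512, 256, 128
print(matvec_flops(m, k), 2 * m * k)         # 261,632 vs 262,144
print(matmat_flops(m, k, n), 2 * m * k * n)  # 33,488,896 vs 33,554,432
```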
What is a KV cache? Why do we store only the key and value in the KV cache and not the query or the softmax output?
Will using a KV cache help you if you are memory bound or compute bound?
Given B=4 sequences all of length L=64, how many FLOPs does it take to perform a single forward pass of single-headed attention with d_model=128 and n_layers=9? Assume the model does NOT have a KV cache and that computing the softmax on a matrix with shape (m, n, n) costs O(mn^2) = c x m x n^2, where c is a constant that we’ll set to c=1.
Given input X.shape = (B, L, d_model) and weights W_Q.shape = W_K.shape = W_V.shape = (d_model, d_model), the forward pass of single-headed attention in a single layer is given by A(X) = softmax(Q K^T / √d_model) V where Q = X W_Q, K = X W_K, V = X W_V.

Let’s break down the FLOPs from each operation:

- Q = X W_Q has shapes (B, L, d_model) @ (d_model, d_model) -> (B, L, d_model) and costs 2 x B x L x d_model^2 FLOPs
- K = X W_K has shapes (B, L, d_model) @ (d_model, d_model) -> (B, L, d_model) and costs 2 x B x L x d_model^2 FLOPs
- V = X W_V has shapes (B, L, d_model) @ (d_model, d_model) -> (B, L, d_model) and costs 2 x B x L x d_model^2 FLOPs
- R = Q K^T has shapes (B, L, d_model) @ (B, d_model, L) -> (B, L, L) and costs 2 x B x L^2 x d_model FLOPs
- scores = softmax(R / √d_model) has shapes (B, L, L) -> (B, L, L) and costs (c+1) x B x L^2 FLOPs
- A = scores @ V has shapes (B, L, L) @ (B, L, d_model) -> (B, L, d_model) and costs 2 x B x L^2 x d_model FLOPs

So the forward pass of one layer of single-headed attention requires 3 x (2 x B x L x d_model^2) + 2 x (2 x B x L^2 x d_model) + (c+1) x B x L^2 FLOPs. Across all n_layers, it costs n_layers x [ 3 x (2 x B x L x d_model^2) + 2 x (2 x B x L^2 x d_model) + (c+1) x B x L^2 ] FLOPs.

Plugging in our values, we get 9 x [ 3 x (2 x 4 x 64 x 128^2) + 2 x (2 x 4 x 64^2 x 128) + (1+1) x 4 x 64^2 ] = 302,284,800 ≈ 3e+8 FLOPs.
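Here's a quick Python sketch that just plugs these values into the formula above:

```python
# Plugging the numbers into the formula above (no KV cache).
B, L, d_model, n_layers, c = 4, 64, 128, 9, 1

qkv     = 3 * (2 * B * L * d_model**2)   # Q, K, V projections
scores  = 2 * (2 * B * L**2 * d_model)   # Q @ K^T and scores @ V
softmax = (c + 1) * B * L**2             # scale by 1/sqrt(d_model) plus softmax

print(n_layers * (qkv + scores + softmax))   # 302284800 ≈ 3e+8
```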
You have the same model as before but now it has a KV cache. How many FLOPs do you save by using the KV cache?
Your model has n_layers=32, n_heads=32, d_head=128 and uses bfloat16 precision. How much storage does a KV cache require for a single token? What about if we have B*L tokens, where B is the batch size and L is the sequence length?
Your model has n_layers=32, n_heads=32, d_head=128, and uses bfloat16 precision and a KV cache. At what sequence length does the KV cache exceed 1GB for batch_size=1?

The KV cache has shape (2, B, L, n_layers, n_heads, d_head), which takes up M = n_bytes x 2 x B x L x n_layers x n_heads x d_head bytes. Here n_bytes = 16 bits / 8 = 2 bytes and M = 1GB = 1e+9 bytes. Therefore, 1e+9 = 2 x 2 x 1 x L x 32 x 32 x 128 = 524,288 x L ≈ 5e+5 x L bytes, so L ≈ 1e+9 / 5e+5 ≈ 2,000 tokens. After roughly 2,000 tokens, the KV cache will exceed 1GB. This is why long conversations quickly exhaust GPU memory.
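As a quick sketch in Python (same formula; it also gives the per-token cache size asked about above, roughly 0.5 MB per token for this model):

```python
# KV cache sizing: M = n_bytes * 2 * B * L * n_layers * n_heads * d_head.
n_layers, n_heads, d_head = 32, 32, 128
n_bytes = 2   # bfloat16
B = 1

bytes_per_token = n_bytes * 2 * n_layers * n_heads * d_head   # 524,288 bytes ≈ 0.5 MB
print(bytes_per_token)
print(1e9 / (B * bytes_per_token))   # ~1907 tokens until the cache passes 1GB
```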
Your GPU has 24GB of memory. Your model has n_layers=32, n_heads=32, d_head=128 and uses bfloat16 precision and a KV cache. How many users can perform inference simultaneously with 500-token sequences vs 2000-token sequences?

The KV cache has shape (2, B, L, n_layers, n_heads, d_head), which takes up M = n_bytes x 2 x B x L x n_layers x n_heads x d_head bytes. Notice 24GB = 2.4e+10 bytes.

L=500: 2.4e+10 = 2 x 2 x B x 500 x 32 x 32 x 128, so B = 24GB / (524KB x 500) ≈ 91. At most B=91 different users can ask this chatbot questions simultaneously.

L=2000: 2.4e+10 = 2 x 2 x B x 2000 x 32 x 32 x 128, so B = 24GB / (524KB x 2000) ≈ 22. At most B=22 different users can ask this chatbot questions simultaneously.
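The same arithmetic as a sketch (note it assumes the KV cache is the only thing in GPU memory; weights and activations would also need to fit):

```python
# Max concurrent users whose KV caches fit in 24GB, ignoring weights/activations.
gpu_bytes = 24e9
n_layers, n_heads, d_head, n_bytes = 32, 32, 128, 2
per_token = n_bytes * 2 * n_layers * n_heads * d_head   # 524,288 bytes

for L in (500, 2000):
    print(L, int(gpu_bytes // (per_token * L)))   # 500 -> 91, 2000 -> 22
```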
Memory bandwidth is the bottleneck, not compute. The GPU's compute units are idle 60% of the time waiting for data. We know it's bandwidth (not a slow computation) because reading cached data should be fast; if it takes 60% of the time, the memory system can't supply data fast enough for the compute units.
When generating the l-th token (1 ≤ l ≤ L), the KV cache stores the l-1 previous keys and values, has shape (2, B, l-1, n_layers, n_heads, d_head), and takes up M(l) = n_bytes × 2 × B × (l-1) × n_layers × n_heads × d_head bytes.
Memory Read: When generating the l-th token, we must read the entire KV cache containing all l-1 previous tokens, requiring M(l) bytes. The cumulative memory read across ALL L tokens is:

R(L) = ∑(l=1 to L) M(l) = n_bytes × 2 × B × n_layers × n_heads × d_head × ∑(l=1 to L) (l-1) = n_bytes × 2 × B × n_layers × n_heads × d_head × L(L-1)/2

(Here we used the identity ∑(i=0 to n) i = n(n+1)/2 with i = l-1, so n = L-1.)
Memory Write: When generating the l-th token, we write only the new key-value pair for that token: n_bytes × 2 × B × n_layers × n_heads × d_head bytes (constant per token). The cumulative memory written across ALL L tokens is:

W(L) = ∑(l=1 to L) (n_bytes × 2 × B × n_layers × n_heads × d_head) = L × (n_bytes × 2 × B × n_layers × n_heads × d_head)
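Here's a small Python sketch of R(L) and W(L) using the same model shape as before; the read/write ratio works out to (L-1)/2, which the next question uses:

```python
# Cumulative KV cache traffic while generating L tokens.
def kv_bytes_per_token(B, n_layers, n_heads, d_head, n_bytes=2):
    return n_bytes * 2 * B * n_layers * n_heads * d_head

def cumulative_read(L, per_token):    # step l reads the l-1 cached tokens
    return per_token * L * (L - 1) // 2

def cumulative_write(L, per_token):   # each step writes one new K,V entry
    return per_token * L

per_token = kv_bytes_per_token(B=1, n_layers=32, n_heads=32, d_head=128)
L = 1000
print(cumulative_read(L, per_token) / cumulative_write(L, per_token))   # 499.5 = (L-1)/2
```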
For L=1000 tokens, how many more/fewer times do we cumulatively read than cumulatively write? What are the implications of this?

From the previous question we know that the cumulative reads from the KV cache are R(L) = n_bytes × 2 × B × n_layers × n_heads × d_head × L(L-1)/2 and the cumulative writes to the KV cache are W(L) = L × n_bytes × 2 × B × n_layers × n_heads × d_head. Notice that memory reads grow quadratically with L but memory writes grow linearly with L, so the cumulative reads grow faster than the cumulative writes.

The read-write ratio is R/W = (L-1)/2, meaning we read (L-1)/2 times more data than we write. For L=1000 tokens, we read (1000-1)/2 ≈ 500 times more data than we write.
Implications: If our model becomes memory-bound, this is likely from reading from memory not from writing to memory.
You’re serving a chatbot. User conversations average 50 tokens but 1% go to 10,000 tokens. How do you handle KV cache memory?
Implement sliding window attention (keep only last N tokens) or progressive offloading (move old cache entries to CPU/disk). The 1% tail drives 99% of memory costs, so optimize for the outliers.
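A toy sketch of the sliding-window idea (the window size is an arbitrary choice, and the deque stands in for real per-layer tensors):

```python
# Toy sketch of sliding-window KV caching: keep only the last N tokens' K,V.
from collections import deque

N = 4096                   # window size (arbitrary)
window = deque(maxlen=N)   # each entry stands in for one token's (K, V) tensors

def append_kv(k, v):
    window.append((k, v))  # once full, the oldest token's K,V is dropped automatically
```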
Production system: 95th percentile latency matters more than throughput. KV cache or large batches? KV cache for consistent low latency per user, even if total throughput (users/sec) is lower. Large batches increase individual request latency due to queuing effects.
Two GPUs: 16GB with 2TB/s bandwidth vs 32GB with 1TB/s bandwidth. Which is better for long conversations? Depends on sequence length. Short sequences (memory bandwidth bound): 16GB GPU wins. Long sequences (memory capacity bound): 32GB GPU wins. The crossover point depends on your specific workload.
Inference cost optimization: KV cache uses 2× memory but 3× faster generation. At what utilization rate do you break even? If you can keep GPUs 66%+ utilized with KV cache (⅔ of capacity), the 3× speed improvement compensates for 2× memory cost. Below 66% utilization, large batches without KV cache may be more cost-effective.
User requests 5 different continuations of same prompt. KV cache strategy? Cache the common prompt prefix once, then branch KV cache only for the different continuations. This avoids recomputing the shared 90% of work while only duplicating the divergent portions.
Why might a model perform WORSE with KV cache enabled in some scenarios? Memory pressure causing GPU memory swapping to CPU, memory bandwidth saturation slowing all operations, or cache management overhead (copying, allocation) exceeding computational savings for short sequences.
Flash Attention reduces memory usage. How does this interact with KV cache benefits? Flash Attention optimizes attention computation memory (intermediate activations), while KV cache optimizes recomputation across time steps. They’re complementary - Flash Attention reduces per-step memory, KV cache reduces cross-step computation.
In transformer training vs inference: why is KV cache irrelevant for training? Training uses teacher forcing - processes entire sequences in parallel with full attention matrices. No autoregressive generation step-by-step, so no opportunity to reuse previous computations.
Model uses rotary positional embeddings (RoPE). How does this affect what we cache? In standard implementations you can cache the rotated keys (and the values, which are never rotated), because each key's rotation depends only on that key's own, fixed absolute position; the relative-position behaviour comes from rotating each new query by its own position. Caching unrotated K and applying the rotation at attention time is only needed when cached positions can later shift, e.g. after evicting tokens from a sliding-window cache.
A sequence has repeated phrases. Could we compress KV cache by deduplicating similar K,V vectors? Theoretically possible, but positional embeddings make even identical tokens have different representations. The attention mechanism depends on position, not just content, making deduplication complex and potentially harmful to model quality.
Mixture of Experts (MoE) model: different tokens activate different experts. How does this complicate KV cache? Each expert may need separate KV caches if they have different dimensions, or the routing decisions affect which cached values are relevant. Cache size scales with number of active experts per token.
Streaming generation: user types while model generates. How do you update KV cache mid-generation? Invalidate cache from the interruption point onward, recompute from the user’s new input position. This requires careful bookkeeping of which cache entries correspond to which part of the conversation state.
Speculative decoding: generate multiple tokens in parallel, then verify. How does KV cache work here? Cache optimistically for all speculated tokens, but maintain checkpoints to rollback if speculation fails. This creates a tree of potential cache states that must be managed efficiently.
Model pruning removes 50% of attention heads. How does this affect KV cache memory and performance trade-offs? Memory usage halves (n_heads reduces by 50%), but attention quality may degrade, potentially requiring longer sequences for same performance. Need to rebalance the memory savings vs quality trade-off.
We store K,V per head in KV cache. Do we also need to store the concatenated output after multi-head attention? No, only store individual K,V per head. The concatenated output gets recomputed each time because it depends on the new query. KV cache only stores the reusable components (keys and values), not query-dependent results.
When we say the KV cache memory “grows linearly with sequence length,” do we mean input prompt length or total length including generated tokens? Total length including all generated tokens. During autoregressive generation, sequence length = original_prompt + tokens_generated_so_far. Each new token increases this total length by 1.
Batch processing sequences of different lengths with padding: do shorter sequences waste KV cache memory? In naive implementations, yes - all sequences get padded to max_length, wasting memory. Better implementations use attention masks to ignore padding during computation and variable-length caching to avoid storing padding positions in the KV cache.
Grouped Query Attention (GQA): multiple query heads share the same K,V heads. How does this change the storage formula for a model using a KV cache? With GQA, groups of query heads share a single K,V head, so the per-token formula becomes 2 × n_bytes × n_layers × n_kv_heads × d_head (where n_kv_heads < n_heads). If 32 query heads share 8 KV heads, you get 4× memory savings while maintaining most of the attention quality.
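A quick sketch of the per-token storage difference, assuming the 32-layer, d_head=128 model from earlier with 8 KV heads for the GQA case:

```python
# Per-token KV cache bytes for standard multi-head attention vs GQA,
# assuming the 32-layer, d_head=128, bf16 model from earlier.
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, n_bytes=2):
    return 2 * n_bytes * n_layers * n_kv_heads * d_head   # 2 = keys and values

mha = kv_bytes_per_token(n_layers=32, n_kv_heads=32, d_head=128)   # 524,288 bytes
gqa = kv_bytes_per_token(n_layers=32, n_kv_heads=8,  d_head=128)   # 131,072 bytes
print(mha // gqa)   # 4x memory savings
```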
Beam search: each beam needs its own KV cache. How does this explode memory? Memory scales as k_beams × sequence_length × cache_size. With 8 beams and 1000 tokens, you need 8× more memory than greedy decoding. When beams split from same parent, cache copying creates temporary memory spikes beyond this 8× baseline.
What memory and time overheads does KV cache add? Memory: 2×n_bytes×n_layers×n_heads×d_head×B×L storage overhead. Time: memory allocation/deallocation, cache management, and quadratic read bandwidth growth (reading more cached data for each new token). These can outweigh benefits for short sequences.
Can we be memory-bound during the forward pass but FLOPS-bound during sampling? Yes! The forward pass reads massive amounts of cached K,V data (memory-bound). The sampling step then does top-k filtering and probabilistic sampling over the vocabulary (FLOPS-bound). Different parts of inference have different bottlenecks depending on sequence length and model size.
Here are my notes, also written with Claude.
FLOPs (floating point operations) counts arithmetic operations (add, multiply, etc.), while FLOPS (floating point operations per second) measures how many of them hardware can perform each second.
Matrix-Vector Multiplication: A(m×k) @ v(k×1) = 2×m×k
FLOPS (k multiply-adds for each of m output elements). The factor of 2 comes from each output element requiring k multiplications + k additions.
Matrix-Matrix Multiplication: A(m×k) @ B(k×n) = 2×m×k×n
FLOPS (k multiply-adds for each of m×n output elements). The factor of 2 comes from each output element requiring k multiplications + k additions.
Compute Bound (FLOPs Bound): The bottleneck is arithmetic operations. GPU compute units are fully utilized, but memory can supply data faster than it’s consumed. Adding more compute power would speed up the process.
Memory Bound (Bandwidth Bound): The bottleneck is data movement. GPU compute units are idle waiting for data from memory. Memory bandwidth cannot supply data fast enough for the available compute. Adding more memory bandwidth would speed up the process.
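A rough roofline-style sketch of this distinction; the peak FLOPS and bandwidth numbers are assumed, A100-class figures, not from any specific spec:

```python
# Rough roofline-style check: a kernel is compute bound when its arithmetic
# intensity (FLOPs per byte moved) exceeds the GPU's FLOPs-per-byte ratio.
# The hardware numbers below are assumptions (roughly A100-class), not a spec.
peak_flops = 312e12          # ~312 TFLOPS of bf16 compute (assumed)
peak_bandwidth = 2.0e12      # ~2 TB/s of HBM bandwidth (assumed)
machine_balance = peak_flops / peak_bandwidth   # ~156 FLOPs per byte

def is_compute_bound(flops, bytes_moved):
    return flops / bytes_moved > machine_balance

# Example: a (d, d) @ (d,) matvec in bf16 does 2*d*d FLOPs while reading
# about 2*d*d bytes of weights -> arithmetic intensity ~1, so memory bound.
d = 4096
print(is_compute_bound(2 * d * d, 2 * d * d))   # False
```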
Parameter Definitions:
- B: batch size
- L: sequence length (prompt plus generated tokens)
- d_model: model (embedding) dimension
- n_layers: number of transformer layers
- n_heads: number of attention heads
- d_head: dimension per attention head
- n_bytes: bytes per element (2 for bfloat16)
- c, C₁: constant cost per element of the softmax

Matrix Dimensions & Descriptions:
- X: (B, L, d_model), input activations
- W_Q, W_K, W_V: (d_model, d_model), projection weights
- Q, K, V: (B, L, d_model), queries, keys, values
- QK^T: (B, L, L), raw attention scores
- softmax(QK^T / √d_model) @ V: (B, L, d_model), attention output
Full Attention Equation: Attention(Q,K,V) = softmax(QK^T / √d_model) @ V
FLOPS Breakdown:
- Q = X W_Q: 2 × B × L × d_model² FLOPS
- K = X W_K: 2 × B × L × d_model² FLOPS
- V = X W_V: 2 × B × L × d_model² FLOPS
- Q K^T: 2 × B × L² × d_model FLOPS
- softmax: C₁ × B × L² FLOPS (constant operations per element)
- scores @ V: 2 × B × L² × d_model FLOPS

Total FLOPS: 6 × B × L × d_model² + 4 × B × L² × d_model + C₁ × B × L²
KV Cache Impact: Only the first term, 6 × B × L × d_model², changes with KV caching. The terms 4 × B × L² × d_model + C₁ × B × L² do not change with KV cache. For simplicity, we only focus on this first term in the rest of this post, but in practice you must add back the two unchanging terms.
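As a sketch, here's that total as a function, where the kv_cache flag just swaps the 6× term for 2× as described (leaving the other two terms unchanged, per the simplification above):

```python
# Per-layer attention FLOPS, with the kv_cache flag swapping the 6x term for 2x.
def attention_flops(B, L, d_model, C1=1, kv_cache=False):
    qkv  = (2 if kv_cache else 6) * B * L * d_model**2    # the only term that changes
    rest = 4 * B * L**2 * d_model + C1 * B * L**2         # unchanged by caching
    return qkv + rest

print(attention_flops(4, 64, 128), attention_flops(4, 64, 128, kv_cache=True))
```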
KV Cache stores the computed key and value vectors from previous tokens during autoregressive generation, avoiding redundant recomputation. Instead of recalculating K and V for all tokens at each step, we cache them and only compute new entries.
Example
Prompt: "The weather today"
Target: Generate " is sunny"
Prefill Phase:
- Input: ["The", "weather", "today"] (3 tokens)
- Compute: K₁,V₁, K₂,V₂, K₃,V₃ in parallel
- Cache: Store all 3 K,V pairs
- Time: Fast parallel processing
Generation Phase:
Token 4 " is":
- Read: K₁,V₁, K₂,V₂, K₃,V₃ from cache
- Compute: Q₄, K₄, V₄
- Attention: Q₄ @ [K₁,K₂,K₃,K₄]^T @ [V₁,V₂,V₃,V₄]
- Cache: Append K₄,V₄
Token 5 " sunny":
- Read: K₁,V₁, K₂,V₂, K₃,V₃, K₄,V₄ from cache
- Compute: Q₅, K₅, V₅
- Attention: Q₅ @ [K₁,K₂,K₃,K₄,K₅]^T @ [V₁,V₂,V₃,V₄,V₅]
- Cache: Append K₅,V₅
Implementation: To store the previous keys and values, the KV cache is defined as a tensor

cache = Tensor.empty(2, B, L, n_layers, n_heads, d_head)

where the keys are stored in cache[0] and the values in cache[1]. The KV cache works by reading from and writing to this tensor. To find the key at token index 6 in sequence 3 at head 9 of layer 14 of the model, we would index

cache[0, 3, 6, 14, 9]

This kind of naive KV cache only works for decoding schemes that keep a single sequence per batch element (greedy or ordinary sampling); it does not directly support schemes that branch into multiple candidates, like beam decoding.
Memory Requirement: This requires storing 2 × B × L × n_layers × n_heads × d_head × n_bytes total bytes in memory, because the cache is a tensor of shape (2, B, L, n_layers, n_heads, d_head) and each element in that tensor takes up n_bytes bytes.
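Here's a minimal sketch of such a cache as a Python class; PyTorch is my choice here (the Tensor.empty pseudocode above is framework-agnostic), and it only supports the append-only, single-sequence usage described:

```python
# Minimal greedy-decoding KV cache sketch (PyTorch is an assumption).
import torch

class KVCache:
    def __init__(self, B, L_max, n_layers, n_heads, d_head, dtype=torch.bfloat16):
        # cache[0] holds keys, cache[1] holds values
        self.cache = torch.empty(2, B, L_max, n_layers, n_heads, d_head, dtype=dtype)
        self.length = 0  # number of tokens cached so far

    def append(self, k, v):
        # k, v: (B, n_layers, n_heads, d_head) for the newest token
        self.cache[0, :, self.length] = k
        self.cache[1, :, self.length] = v
        self.length += 1

    def read(self):
        # all cached keys/values so far: (B, length, n_layers, n_heads, d_head) each
        return self.cache[0, :, :self.length], self.cache[1, :, :self.length]

    def nbytes(self):
        # bytes actually in use: 2 * B * length * n_layers * n_heads * d_head * n_bytes
        return self.cache[:, :, :self.length].numel() * self.cache.element_size()
```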
Prefill Phase (Time to First Token - TTFT):
Generation Phase (Time Between Tokens):
When FLOPS Bound (KV Cache Helps):
When Memory Bound (KV Cache May Hurt):
Example: Long sequences (1000+ tokens) often become memory bound because reading cached data dominates the time, making the FLOPS savings irrelevant.
Advantages | Disadvantages |
---|---|
3× FLOPS reduction for Q,K,V computation | Linear memory growth with sequence length |
Enables real-time generation for interactive applications | Reduces maximum batch size due to memory constraints |
Critical for long context models (would be unusably slow otherwise) | Memory bandwidth bottleneck for very long sequences |
Powers production chat systems at scale | Beam search memory explosion when caches are copied |
Amortizes compute cost over conversation length | No benefit for training (only inference optimization) |
Enables streaming responses for better UX | Overhead for short sequences (< 50 tokens) |
Facilitates longer conversations without timeout | Cache sharing impossible between different users |
Generation | Without KV Cache | With KV Cache | Explanation |
---|---|---|---|
Single token (Lth token, batch size B) | 3 × 2 × d_model² × B + C | 2 × d_model² × B + C | Without: compute Q,K,V from scratch (3 matrices × 2 FLOPS factor). With: only compute Q, reuse cached K,V. C = attention matrix operations (same for both) |
Cumulative tokens (tokens 1, …, L, batch size B) | 3 × 2 × d_model² × B × L + C × L | 2 × d_model² × B × L + C × L | Without: recompute Q,K,V for every token. With: compute Q for each new token only. C scales with sequence length |
Generation | Memory Stored | Bandwidth: Read | Bandwidth: Write | Explanation |
---|---|---|---|---|
Single token (Lth token, batch size B) | 2 × n_bytes × n_layers × n_heads × d_head × B × L | 2 × n_bytes × n_layers × n_heads × d_head × B × (L-1) | 2 × n_bytes × n_layers × n_heads × d_head × B | Memory: total cache size grows with L. Read: access all previous tokens for attention. Write: store current token’s K,V |
Cumulative tokens (tokens 1, …, L, batch size B) | 2 × n_bytes × n_layers × n_heads × d_head × B × L | 2 × n_bytes × n_layers × n_heads × d_head × B × L × (L-1)/2 | 2 × n_bytes × n_layers × n_heads × d_head × B × L | Memory: same final size. Read: cumulative bandwidth = 0+1+2+…+(L-1) = L(L-1)/2 per batch. Write: one entry per token generated |
Definition: Beam search maintains the k most promising sequences at each step, exploring multiple paths simultaneously to find higher-quality outputs than greedy decoding.
Algorithm:
Concrete Example with KV Cache:
Model: 2 layers, 4 heads, d_head=64, n_bytes=2
Beam width k=2, candidates per beam r=3
Starting sequence: "The weather" (2 tokens)
Initial State:
Beam 1: "The weather"
KV Cache 1: K₁,V₁ (for "The"), K₂,V₂ (for "weather")
Cache size: 2 (K,V) × 2 tokens × 2 layers × 4 heads × 64 × 2 bytes = 4,096 bytes
Beam 2: "The weather"
KV Cache 2: K₁,V₁ (for "The"), K₂,V₂ (for "weather")
Cache size: 4,096 bytes (same starting point)
Step 1: Generate third token (each beam explores 3 candidates)
Beam 1 candidates: "is" (0.6), "was" (0.3), "feels" (0.2)
Beam 2 candidates: "today" (0.5), "looks" (0.4), "seems" (0.1)
Step 2: Select top-2 sequences across all 6 candidates
Selected: "The weather is" (0.6), "The weather today" (0.5)
MEMORY EXPLOSION POINT:
New Beam 1: "The weather is"
- Needs: COPY of original Beam 1 cache + new K₃,V₃ for "is"
- Cache: K₁,V₁, K₂,V₂, K₃ⁱˢ,V₃ⁱˢ = 6,144 bytes
New Beam 2: "The weather today"
- Needs: COPY of original Beam 2 cache + new K₃,V₃ for "today"
- Cache: K₁,V₁, K₂,V₂, K₃ᵗᵒᵈᵃʸ,V₃ᵗᵒᵈᵃʸ = 6,144 bytes
Total Memory: 12,288 bytes (3× the original 4,096-byte cache, from beam splitting!)
Note: The beams started identical but diverged due to different candidate selection
Generation | FLOPS (vs Greedy) | Memory Stored | Bandwidth: Read | Explanation |
---|---|---|---|---|
Single token (Lth token, k beams) | k × (2 × d_model² + C) | k × 2 × n_bytes × n_layers × n_heads × d_head × L | k × 2 × n_bytes × n_layers × n_heads × d_head × (L-1) | k separate sequences, each needs own KV cache and computation |
Cumulative tokens (tokens 1, …, L, k beams) | k × (2 × d_model² × L + C × L) | k × 2 × n_bytes × n_layers × n_heads × d_head × L | k × 2 × n_bytes × n_layers × n_heads × d_head × L × (L-1)/2 | Memory can temporarily spike beyond k× during beam splitting when caches must be copied |
Key Insight: Beam search with KV cache uses k× more memory than greedy decoding, but beam splitting during search can cause temporary memory explosions when multiple beams inherit from the same parent cache.
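A small sketch of the memory arithmetic from the beam search example above:

```python
# Memory arithmetic from the beam search example: toy model with
# 2 layers, 4 heads, d_head=64, bf16 (n_bytes=2).
def cache_bytes(n_tokens, n_layers=2, n_heads=4, d_head=64, n_bytes=2):
    return 2 * n_bytes * n_layers * n_heads * d_head * n_tokens   # 2 = keys and values

start = cache_bytes(2)             # "The weather": 4,096 bytes for one beam
k = 2
after_split = k * cache_bytes(3)   # two 3-token beams, each a full copy: 12,288 bytes
print(start, after_split, after_split / start)   # 4096 12288 3.0
```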
Some great resources I found: