The Magic of the KV Cache

Date: September 13, 2025

At least for now, one of the most important optimizations inside every Transformer-based LLM service is the KV Cache.

I'll explore what it is, walk through a concrete example with matrix sizes, and explain why this "magic" is reserved for inference and not used during training.

The Problem: The Quadratic Attention Matrix of Inference

First, let's recall the core of the Transformer: the Scaled Dot-Product Attention formula.

$$ \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V $$

Here, $Q$ (Query), $K$ (Key), and $V$ (Value) are matrices. For a sequence of length $N$, each has shape $(N, d_k)$, where $d_k$ is the dimension of the attention head (strictly, $V$ has width $d_v$, which is usually equal to $d_k$). Inside the softmax, $Q K^T$ is a matrix multiplication of shapes $(N \times d_k) \cdot (d_k \times N)$, which produces the $(N \times N)$ attention score matrix. The softmax is applied to each row of this matrix, and the resulting $(N \times N)$ matrix holds the weights applied to $V$. In a decoder-only model, a causal mask also blocks each position from attending to later positions, so row $i$ only sees tokens $1$ through $i$.
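To make the shapes concrete, here is a minimal pure-Python sketch of the formula above. Matrices are plain lists of rows, and the helper names (`matmul`, `transpose`, `softmax`, `attention`) and toy values are illustrative only. The `causal` flag assumes a decoder-only model.

```python
import math

def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix, both stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    # Subtract the max for numerical stability; masked entries are -inf and become 0.
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V, causal=True):
    """softmax(Q K^T / sqrt(d_k)) V, optionally hiding future positions."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))          # (N x d_k)·(d_k x N) -> (N x N)
    scale = math.sqrt(d_k)
    weights = []
    for i, row in enumerate(scores):
        scaled = [s / scale if (not causal or j <= i) else -math.inf
                  for j, s in enumerate(row)]
        weights.append(softmax(scaled))       # softmax applied to each row
    return weights, matmul(weights, V)        # (N x N)·(N x d_v) -> (N x d_v)

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights, out = attention(Q, K, V)
print(weights[0])   # [1.0, 0.0, 0.0]: the first token can only see itself
```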

During inference (the decoding stage), a model works autoregressively: it generates one new token at a time, conditioned on all the tokens it has seen so far. Imagine we've generated the sequence "The train arrives". To generate the next word, the model needs to compute attention over this 3-word context.

A naive approach would look like this:

  1. Input: "The train arrives"
  2. Generate Token 4: Compute Q, K, V for all 3 words. Calculate a $3 \times 3$ attention matrix. Generate the word "at".
  3. Input: "The train arrives at"
  4. Generate Token 5: Compute Q, K, V for all 4 words. Calculate a new $4 \times 4$ attention matrix. Generate the word "12".

Do you see the problem? To generate the 5th token, we recalculated the Key and Value vectors for "The", "train", and "arrives" even though we had just calculated them in the previous step. This is incredibly wasteful. As the sequence length $N$ grows, the attention work at each step scales quadratically, $O(N^2)$, because the full $Q K^T$ matrix is rebuilt every time even though only its last row is new.
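One rough way to see the waste is to count work. The sketch below tallies the K/V rows projected and the attention scores computed for the example above, which has a 3-token prompt and 2 generated tokens. The counting model is my own simplification: a single head, the full $N \times N$ score matrix counted even where the mask would discard entries, and everything else in the layer ignored.

```python
def naive_cost(prompt_len, new_tokens):
    """(K/V rows projected, attention scores computed) when everything is recomputed."""
    kv_rows = scores = 0
    for step in range(new_tokens):
        n = prompt_len + step      # context length when predicting this token
        kv_rows += n               # K and V rebuilt for every context token
        scores += n * n            # full n x n score matrix rebuilt
    return kv_rows, scores

def cached_cost(prompt_len, new_tokens):
    """The same counts when K and V are cached after the prefill."""
    kv_rows, scores = prompt_len, prompt_len * prompt_len   # prefill of the prompt
    for step in range(1, new_tokens):
        n = prompt_len + step
        kv_rows += 1               # only the newest token is projected
        scores += n                # one row of n scores
    return kv_rows, scores

print(naive_cost(3, 2), cached_cost(3, 2))   # (7, 25) (4, 13)
```

The gap is small for two tokens, but the naive score count grows cubically in the total length while the cached count grows quadratically.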

The Solution: Don't Recalculate, Cache!

The core insight behind the KV Cache is simple but profound:

At inference time, a token's Key and Value vectors never change once computed. They depend only on that token's input to the layer and the model's fixed weights. Because of the causal mask, that input depends only on the token itself and the tokens before it, so appending new tokens never alters them. Only the Query of the newest token is needed at each generation step. The new token's attention output is the last row of $\text{softmax}(Q K^T / \sqrt{d_k})$ multiplied by $V$, and that row needs just the new query $q_N$ together with all of the keys and values:

$$ o_N = \text{softmax}\left( \frac{q_N K^T}{\sqrt{d_k}} \right) V $$

So even though the attention matrix conceptually grows from $(N-1) \times (N-1)$ to $N \times N$, inference only cares about its last row when generating a new token.
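The claim that the last row is all we need can be checked numerically. This sketch builds the full causally masked $N \times N$ attention and compares its last output row with attention computed from the single newest query. The random toy matrices and the function names are illustrative assumptions.

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def full_causal_attention(Q, K, V):
    """Build the whole N x N score matrix, mask the future, return all N outputs."""
    d_k, n = len(K[0]), len(Q)
    scores = matmul(Q, [list(c) for c in zip(*K)])          # (N x N)
    weights = [softmax([scores[i][j] / math.sqrt(d_k) if j <= i else -math.inf
                        for j in range(n)]) for i in range(n)]
    return matmul(weights, V)                               # (N x d_v)

def last_token_attention(q_new, K, V):
    """Only the newest query: one row of scores against every key."""
    d_k = len(K[0])
    scores = [sum(a * b for a, b in zip(q_new, k)) / math.sqrt(d_k) for k in K]  # (1 x N)
    return matmul([softmax(scores)], V)[0]                  # (1 x d_v)

random.seed(0)
N, d_k = 5, 4
Q = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(N)]

full_last_row = full_causal_attention(Q, K, V)[-1]
cached_row = last_token_attention(Q[-1], K, V)
# The other N - 1 rows of the full matrix were never needed for this token.
assert all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(full_last_row, cached_row))
```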

The KV Cache is simply a storage space in memory that holds the Key and Value vectors for all the preceding tokens in the sequence.
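As a data structure, the cache can be as simple as two growing lists of rows per attention head. The class below is a hypothetical, minimal sketch. Production systems usually store these as preallocated GPU tensors, one K/V pair per layer, rather than Python lists.

```python
class KVCache:
    """Hypothetical per-head store of Key and Value rows, one row per token seen so far."""

    def __init__(self, d_k, d_v):
        self.d_k, self.d_v = d_k, d_v
        self.keys = []     # grows to (N x d_k)
        self.values = []   # grows to (N x d_v)

    def append(self, k_rows, v_rows):
        """Add one or more tokens' keys and values (prefill adds many, decode adds one)."""
        if len(k_rows) != len(v_rows):
            raise ValueError("need one value row per key row")
        for k, v in zip(k_rows, v_rows):
            if len(k) != self.d_k or len(v) != self.d_v:
                raise ValueError("row has the wrong width")
            self.keys.append(list(k))
            self.values.append(list(v))

    def __len__(self):
        return len(self.keys)

cache = KVCache(d_k=2, d_v=2)
cache.append([[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])   # prefill: 2 tokens
cache.append([[1.0, 1.0]], [[5.0, 6.0]])                          # decode: 1 token
print(len(cache))   # 3
```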

A Concrete Example: Let's Build a KV Cache

Let's set up a scenario for a single attention head in a decoder-only transformer.

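Hyperparameters:

  • Embedding dimension $d_{\text{model}} = 512$
  • Head dimension $d_k = d_v = 64$

The model has three learned weight matrices:

  • $W_q$ (Query projection), size $512 \times 64$
  • $W_k$ (Key projection), size $512 \times 64$
  • $W_v$ (Value projection), size $512 \times 64$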

Stage 1: The Prefill Phase (Processing the Prompt)

The user provides the prompt: "Hello world".

  1. Input: The model receives the embeddings for "Hello" and "world". This is a matrix $X$ of size $2 \times 512$.
  2. Calculate K and V: The model computes the Key and Value for the entire prompt at once.
    • $K_{\text{prompt}} = X W_k$ => $(2 \times 512) \cdot (512 \times 64)$ => $K_{\text{prompt}}$ is $2 \times 64$
    • $V_{\text{prompt}} = X W_v$ => $(2 \times 512) \cdot (512 \times 64)$ => $V_{\text{prompt}}$ is $2 \times 64$
  3. Store in Cache: These two matrices are now stored in our KV Cache. The cache now holds 2 tokens' worth of state.
    • $K_{\text{cache}} = K_{\text{prompt}}$ (size $2 \times 64$)
    • $V_{\text{cache}} = V_{\text{prompt}}$ (size $2 \times 64$)

The model also computes $Q_{\text{prompt}}$ and performs a full, causally masked attention calculation over the prompt. The output at the last position ("world") is used to predict the very first new token. Let's say it predicts "I".
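The prefill stage can be sketched with the example's sizes ($d_{\text{model}} = 512$, $d_k = 64$). The weights and the embeddings for "Hello" and "world" are random stand-ins, since we have no trained model. The helper names are illustrative.

```python
import random

D_MODEL, D_K = 512, 64   # sizes from the example

def matmul(a, b):
    """(n x m)·(m x p) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def shape(m):
    return (len(m), len(m[0]))

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.02) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(42)
# Random stand-ins for the learned projections (assumption: real weights come from training).
W_q = random_matrix(D_MODEL, D_K, rng)
W_k = random_matrix(D_MODEL, D_K, rng)
W_v = random_matrix(D_MODEL, D_K, rng)

# Stand-in embeddings for "Hello" and "world": X is (2 x 512).
X = random_matrix(2, D_MODEL, rng)

def prefill(X):
    """Project the whole prompt at once and return (Q_prompt, K_cache, V_cache)."""
    Q = matmul(X, W_q)   # (2 x 512)·(512 x 64) -> (2 x 64)
    K = matmul(X, W_k)   # stored in the cache
    V = matmul(X, W_v)   # stored in the cache
    return Q, K, V

Q_prompt, K_cache, V_cache = prefill(X)
print(shape(K_cache), shape(V_cache))   # (2, 64) (2, 64)
```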

Stage 2: The Generation Phase (Generating Token 3)

Now the efficient, step-by-step process begins.

  1. Input: The model's input is only the embedding for the new token, "I". This is a matrix $x_{\text{new}}$ of size $1 \times 512$.
  2. Calculate New q, k, v: The model computes vectors for this single token only.
    • $q_{\text{new}} = x_{\text{new}} W_q$ => $(1 \times 512) \cdot (512 \times 64)$ => $q_{\text{new}}$ is $1 \times 64$
    • $k_{\text{new}} = x_{\text{new}} W_k$ => $(1 \times 512) \cdot (512 \times 64)$ => $k_{\text{new}}$ is $1 \times 64$
    • $v_{\text{new}} = x_{\text{new}} W_v$ => $(1 \times 512) \cdot (512 \times 64)$ => $v_{\text{new}}$ is $1 \times 64$
  3. Retrieve and Append to Cache: The model pulls the existing $K_{\text{cache}}$ and $V_{\text{cache}}$ and appends the new $k_{\text{new}}$ and $v_{\text{new}}$.
    • $K_{\text{full}} = \text{concat}(K_{\text{cache}}, k_{\text{new}})$ => $K_{\text{full}}$ is now $3 \times 64$
    • $V_{\text{full}} = \text{concat}(V_{\text{cache}}, v_{\text{new}})$ => $V_{\text{full}}$ is now $3 \times 64$

    These larger $K_{\text{full}}$ and $V_{\text{full}}$ matrices become the new state of the cache.

  4. Calculate Attention: This is the magic step. We use our single new query $q_{\text{new}}$ to attend to all the keys we've ever seen.
    • $\text{scores} = \frac{q_{\text{new}} K_{\text{full}}^T}{\sqrt{d_k}}$ => $(1 \times 64) \cdot (64 \times 3)$ => scores is $1 \times 3$

    Notice we did not create a $3 \times 3$ matrix. We only computed the single row of scores we need, which reduces the cost of this step from $O(N^2)$ to $O(N)$.

  5. Get Final Output: We apply softmax to our $1 \times 3$ scores and multiply by $V_{\text{full}}$, $(1 \times 3) \cdot (3 \times 64)$, to get our $1 \times 64$ context vector, which is then passed to the next layer.
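The generation steps above can be sketched as a single `decode_step` function. This is a hypothetical name, and the weights and embeddings are random stand-ins. The function projects only the new token, appends its key and value to the cache, and computes one row of scaled scores.

```python
import math
import random

D_MODEL, D_K = 512, 64   # sizes from the example

def matmul(a, b):
    """(n x m)·(m x p) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def shape(m):
    return (len(m), len(m[0]))

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.02) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(7)
# Random stand-ins for the learned projections (an assumption, not trained weights).
W_q, W_k, W_v = (random_matrix(D_MODEL, D_K, rng) for _ in range(3))

# State after prefilling the 2-token prompt "Hello world" (stand-in embeddings).
X_prompt = random_matrix(2, D_MODEL, rng)
K_cache, V_cache = matmul(X_prompt, W_k), matmul(X_prompt, W_v)   # (2 x 64) each

def decode_step(x_new, K_cache, V_cache):
    """One generation step for one head; returns (context, K_full, V_full)."""
    q_new = matmul(x_new, W_q)                   # (1 x 512)·(512 x 64) -> (1 x 64)
    k_new = matmul(x_new, W_k)                   # (1 x 64)
    v_new = matmul(x_new, W_v)                   # (1 x 64)
    K_full = K_cache + k_new                     # append one row -> (3 x 64)
    V_full = V_cache + v_new                     # append one row -> (3 x 64)
    scores = [sum(a * b for a, b in zip(q_new[0], k)) / math.sqrt(D_K)
              for k in K_full]                   # one row of scores: (1 x 3)
    context = matmul([softmax(scores)], V_full)  # (1 x 3)·(3 x 64) -> (1 x 64)
    return context, K_full, V_full

x_new = random_matrix(1, D_MODEL, rng)            # stand-in embedding for "I"
context, K_cache, V_cache = decode_step(x_new, K_cache, V_cache)
print(shape(context), shape(K_cache), shape(V_cache))   # (1, 64) (3, 64) (3, 64)
```

No causal mask is needed here, because the newest token is allowed to attend to every cached position.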

This process repeats for every new token, with the cache growing by one row at each step.
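To check that the cache changes only the cost and not the result, the sketch below runs a few steps both ways. The cached path uses a growing cache, and the naive path recomputes K and V for the whole prefix at every step. It uses toy sizes ($d_{\text{model}} = 16$, $d_k = 4$) and hypothetical stand-in token embeddings so that it runs instantly.

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, K, V):
    """Attention output for a single query row against all given keys and values."""
    d_k = len(K[0])
    w = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K])
    return matmul([w], V)[0]

rng = random.Random(1)
D_MODEL, D_K = 16, 4     # toy sizes so the check runs instantly
W_q, W_k, W_v = ([[rng.gauss(0, 1) for _ in range(D_K)] for _ in range(D_MODEL)]
                 for _ in range(3))
# Hypothetical stand-in embeddings: a 2-token prompt followed by 4 generated tokens.
tokens = [[rng.gauss(0, 1) for _ in range(D_MODEL)] for _ in range(6)]

# Cached path: prefill the prompt, then feed one token at a time.
K_cache = matmul(tokens[:2], W_k)
V_cache = matmul(tokens[:2], W_v)
cached_outputs = []
for x in tokens[2:]:
    K_cache += matmul([x], W_k)      # the cache grows by one row per step
    V_cache += matmul([x], W_v)
    q = matmul([x], W_q)[0]
    cached_outputs.append(attend(q, K_cache, V_cache))

# Naive path: recompute K and V for the whole prefix at every step.
naive_outputs = []
for n in range(3, 7):
    prefix = tokens[:n]
    K, V = matmul(prefix, W_k), matmul(prefix, W_v)
    q = matmul([prefix[-1]], W_q)[0]
    naive_outputs.append(attend(q, K, V))

print(len(K_cache))   # 6
```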

Why Don't We Use a KV Cache During Training?

If this is so efficient, why not use it during training? There are two fundamental reasons.

1. Backpropagation Needs the Full Matrix

Training is all about updating the model's weights via backpropagation. To calculate the gradients for the weight matrices ($W_q$, $W_k$, $W_v$), the algorithm needs to know how every part of the computation contributed to the final error. That means the attention weights for every position must be available during the backward pass. They are either kept in memory from the forward pass as "activations" or recomputed on the fly. The KV cache path computes only the last row, so it produces neither the outputs nor the activations for the other $N-1$ positions, whose losses also contribute to the gradients.

2. Training Thrives on Parallelism

During training, we use "Teacher Forcing," where the model is given the entire correct sequence at once. GPUs/ASICs are masters of parallel computation, so computing one large, causally masked $N \times N$ attention matrix is vastly more efficient than performing $N$ separate sequential steps. Because the mask restricts each row to its own prefix, the model can calculate the loss for every single token's prediction (e.g., given "The" predict "train"; given "The train" predict "arrives") all at the same time, making training massively faster.
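The sketch below shows how one sequence becomes many training examples under teacher forcing. The inputs are the sequence shifted by one, and each row of the causal mask defines the context for one next-token target. The function names are illustrative, and the tokens are the words from the earlier example.

```python
def causal_mask(n):
    """Row i may attend to columns 0..i (True = allowed)."""
    return [[j <= i for j in range(n)] for i in range(n)]

def teacher_forcing_pairs(tokens):
    """Every (context, next-token target) pair that one masked forward pass trains on."""
    inputs, targets = tokens[:-1], tokens[1:]
    mask = causal_mask(len(inputs))
    pairs = []
    for i, row in enumerate(mask):
        context = [tok for tok, allowed in zip(inputs, row) if allowed]
        pairs.append((context, targets[i]))
    return pairs

for context, target in teacher_forcing_pairs(["The", "train", "arrives", "at", "12"]):
    print(" ".join(context), "->", target)
# The -> train
# The train -> arrives
# The train arrives -> at
# The train arrives at -> 12
```

All four predictions come out of a single parallel pass, which is exactly the situation where caching one row at a time would not help.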

Conclusion

The KV Cache is not a change to the Transformer architecture; it's a brilliant implementation detail that makes inference efficient. It trades memory (the stored K and V vectors, which grow linearly with sequence length) for a massive reduction in redundant computation. In doing so, it turns the attention work for each new token from a quadratic computation into a linear-time one.
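To put a number on the memory side of the trade, the sketch below multiplies out the cache size for one sequence. The leading factor of 2 accounts for K and V. The 16-bit value size and the larger configuration are hypothetical assumptions chosen for illustration, not figures for any particular model.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, d_head, bytes_per_value):
    """Bytes for K and V (the leading 2) across all layers and heads for one sequence."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_value

# The single-head example from this post, assuming 16-bit values: 2 tokens cached.
print(kv_cache_bytes(seq_len=2, n_layers=1, n_kv_heads=1, d_head=64, bytes_per_value=2))  # 512
# A hypothetical larger configuration: 32 layers, 32 heads of width 128, 4096 tokens.
print(kv_cache_bytes(4096, 32, 32, 128, 2) / 2**30, "GiB")   # 2.0 GiB
```

Serving a batch of sequences multiplies this figure by the batch size, which is why the cache often dominates inference memory for long contexts.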