GPT-OSS Attention Mechanisms

A Study of Grouped-Query Attention, Sliding Windows, and the Attention Sink, with Code

As an engineer, I was intrigued when the gpt-oss model card mentioned "banded window attention". This led me on a journey to understand how the core attention mechanism has been optimized for efficiency. This post summarizes that journey, starting with the core code and moving to the key concepts that make these models feasible. Low-level GPU optimizations such as FlashAttention are interesting topics in their own right, but I won't dive into them here.

1. The Core Attention Function

The code referenced in this post is the scaled dot-product attention (SDPA) function from the GPT-OSS project's GitHub repository.

This code shows several key concepts beyond the standard attention formula.


import torch


def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    # GQA: Expand K, V to match Q's shape for dot product
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    
    # Sink token for numerical stability
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    
    # Causal Mask (hides future tokens)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    
    # Sliding Window Mask (restricts attention to a window)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    
    # Attention scores per (KV head, group slot): shape (n_heads, q_mult, n_tokens, n_tokens)
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K)
    QK *= sm_scale
    QK += mask[None, None, :, :]
    # Append the sink logit as an extra "key" column before the softmax
    QK = torch.cat([QK, S], dim=-1)
    W = torch.softmax(QK, dim=-1)
    # Drop the sink's weight; only real tokens contribute to the output
    W = W[..., :-1]
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)  # (n_tokens, n_heads * q_mult * d_head)
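
For orientation, here is a minimal usage sketch. The shapes below (6 tokens, 8 KV heads, 8 query heads per KV head, 64-dimensional heads) and the scale and window values are illustrative assumptions, not necessarily the real gpt-oss configuration.

import torch

# Illustrative shapes (assumptions, not the actual gpt-oss configuration)
n_tokens, n_heads, q_mult, d_head = 6, 8, 8, 64

Q = torch.randn(n_tokens, n_heads, q_mult, d_head)  # q_mult query heads per KV head
K = torch.randn(n_tokens, n_heads, d_head)          # one K head shared by each group
V = torch.randn(n_tokens, n_heads, d_head)
S = torch.randn(n_heads, q_mult)                    # learned sink logit, one per query head

out = sdpa(Q, K, V, S, sm_scale=1.0 / d_head**0.5, sliding_window=3)
print(out.shape)  # torch.Size([6, 4096]) -> (n_tokens, n_heads * q_mult * d_head)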
                

2. Decoding the Code: A Step-by-Step Breakdown

Let's break down the purpose of each key part of this function. The magic lies in how the model uses these components to balance performance with model quality.

Step 1: Tensor Shapes and Grouped-Query Attention (GQA)

The first line unpacks the tensor shapes: n_tokens, n_heads, q_mult, and d_head. The model uses Grouped-Query Attention (GQA): multiple query heads share a single set of key (K) and value (V) heads within a "group", and q_mult is the number of query heads in each group. This differs from standard multi-head attention, where each query head has its own KV head, and from multi-query attention, where all query heads share a single KV head.

The lines K = K[:, :, None, :].expand(-1, -1, q_mult, -1) and V = V[:, :, None, :].expand(-1, -1, q_mult, -1) are crucial here. Since the K and V tensors are shared across a group of q_mult query heads, they have a smaller shape. The .expand() operation effectively duplicates the tensors to match the number of query heads, preparing the data for the efficient batched matrix multiplication (torch.einsum) that follows.
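
A small sketch of what the expand does, with made-up shapes: it adds a q_mult axis and broadcasts over it as a view, so the q_mult "copies" share the original K storage rather than allocating new memory.

import torch

n_tokens, n_heads, q_mult, d_head = 6, 2, 4, 8  # illustrative shapes
K = torch.randn(n_tokens, n_heads, d_head)

K_expanded = K[:, :, None, :].expand(-1, -1, q_mult, -1)
print(K.shape)           # torch.Size([6, 2, 8])
print(K_expanded.shape)  # torch.Size([6, 2, 4, 8])

# expand() returns a broadcast view: no data is copied.
print(K_expanded.data_ptr() == K.data_ptr())  # True

# All q_mult query heads in a group see the same key vector.
print(torch.equal(K_expanded[0, 0, 0], K_expanded[0, 0, 3]))  # True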

Step 2: Causal Masking and Sliding Window

The mask is a core component. The line mask = torch.triu(..., diagonal=1) creates an upper triangular matrix of -inf values. This is the causal mask, which ensures that during training, each token can only attend to tokens that came before it. This is a fundamental requirement for autoregressive models that generate text sequentially.

Causal Mask Visualization (6x6 example)

  0   -∞   -∞   -∞   -∞   -∞
  0    0   -∞   -∞   -∞   -∞
  0    0    0   -∞   -∞   -∞
  0    0    0    0   -∞   -∞
  0    0    0    0    0   -∞
  0    0    0    0    0    0

Here, -∞ represents negative infinity: after the softmax, those positions receive zero weight, so the model cannot "see" future tokens.
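
This pattern can be reproduced with the same call the sdpa function uses (a standalone sketch, using torch.full in place of Q.new_full):

import torch

n_tokens = 6
causal = torch.triu(torch.full((n_tokens, n_tokens), -float("inf")), diagonal=1)
print(causal)
# tensor([[0., -inf, -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0.]])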

The if sliding_window > 0 block adds a second mask. By adding a lower triangular matrix of -inf values, it restricts each token to attending only to a fixed-size window of tokens immediately preceding it. This "banded" attention pattern is a form of sparse attention. It reduces the cost of the attention matrix from quadratic in the sequence length, O(N^2) for full attention, to O(N·W) for a fixed window of size W, i.e., linear in N, which is essential for scaling to very long sequences.

Sliding Window Mask (6x6 example, window=3)

  0    0    0    0    0    0
  0    0    0    0    0    0
  0    0    0    0    0    0
 -∞    0    0    0    0    0
 -∞   -∞    0    0    0    0
 -∞   -∞   -∞    0    0    0

This mask is added to the causal mask and prevents attention to tokens outside the sliding window (here, a window of 3: each token attends to itself and the two tokens before it). The combined "banded" pattern is shown in the sketch below.
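
Combining the two pieces the way the sdpa function does (a standalone sketch with window=3) makes the band explicit; each query position keeps only itself and the two preceding positions:

import torch

n_tokens, sliding_window = 6, 3
mask = torch.triu(torch.full((n_tokens, n_tokens), -float("inf")), diagonal=1)
mask += torch.tril(torch.full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window)

# Key positions each query position may attend to:
for q in range(n_tokens):
    print(q, torch.where(mask[q] == 0)[0].tolist())
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [1, 2, 3]
# 4 [2, 3, 4]
# 5 [3, 4, 5]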

Step 3: The "Sink" Vector for Numerical Stability

The S tensor, or sink vector, is a fascinating detail. It is a trainable parameter whose values are learned during training, and its role here is numerical stability. With sparse attention, a token's row of attention scores can be heavily or even entirely masked to negative infinity, and applying the softmax to a row that is all -inf produces NaN, crashing the training. By concatenating the sink's scores to the attention matrix (QK = torch.cat([QK, S], dim=-1)), we guarantee at least one finite value in every row, which keeps the softmax well defined and stabilizes training.

Visualizing the Problem and Solution

1. The Problem: Fully Masked Attention

When a token has no valid tokens to attend to, its scores become `-infinity`.

Masked scores:    [ -∞,  -∞,  -∞ ]
Softmax result:   [ NaN, NaN, NaN ]

This `NaN` result crashes training.

2. The Solution: The Sink Vector

A learned value `S` is added, guaranteeing a finite number for the softmax.

Scores + sink:    [ -∞, -∞, -∞, S ]
Softmax result:   [ 0,  0,  0,  1 ]

The weight for `S` is then dropped, leaving a valid `[0, 0, 0]` distribution.

The line W = W[..., :-1] then removes the attention weight assigned to the sink. This ensures that the sink vector serves its purpose of stabilizing the softmax without influencing the final attention output.
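
A tiny numeric sketch of the failure and the fix (standalone; the 0.5 below is a stand-in for a learned sink logit):

import torch

scores = torch.tensor([-float("inf"), -float("inf"), -float("inf")])
print(torch.softmax(scores, dim=-1))  # tensor([nan, nan, nan])

sink = torch.tensor([0.5])  # stand-in for the learned sink logit S
scores_with_sink = torch.cat([scores, sink])
weights = torch.softmax(scores_with_sink, dim=-1)
print(weights)       # tensor([0., 0., 0., 1.])
print(weights[:-1])  # tensor([0., 0., 0.]) -> the sink's weight is dropped, as in W[..., :-1]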

Step 4: Training vs. Inference

It's important to understand that this code snippet is a vanilla implementation suited to training (or to processing a full prompt), where the entire sequence is available at once. Inference (generation) works differently: the model produces one new token at a time, so the query side has a single token, and a KV cache stores the keys and values of all previously processed tokens. The attention for the new token's query is computed only against the cached keys and values, and with a sliding window it is further restricted to the last sliding_window entries in the cache, keeping the per-token computation constant and low. A rough sketch of this idea follows.
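
The sketch below is conceptual, not the project's actual inference code: it handles a single head, ignores GQA and the sink, and keeps the full cache while attending only to the last sliding_window entries (in practice the cache itself can also be truncated to the window so memory stays bounded).

import torch

def decode_step(q_new, k_new, v_new, k_cache, v_cache, sliding_window):
    # q_new, k_new, v_new: (d_head,) projections for the single new token.
    # k_cache, v_cache: (t, d_head) keys/values of previously processed tokens.
    k_cache = torch.cat([k_cache, k_new[None, :]], dim=0)
    v_cache = torch.cat([v_cache, v_new[None, :]], dim=0)

    # Only the last `sliding_window` positions matter for the new token.
    k_win = k_cache[-sliding_window:]
    v_win = v_cache[-sliding_window:]

    scores = (k_win @ q_new) / q_new.shape[-1] ** 0.5  # (<= sliding_window,)
    weights = torch.softmax(scores, dim=-1)
    out = weights @ v_win                              # (d_head,)
    return out, k_cache, v_cache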

Conclusion

These techniques—Grouped-Query Attention for KV cache efficiency, Sliding Window Attention for linear scalability with long contexts, and the Sink Vector for numerical stability—are interesting features that make modern large language models practical and performant.