gpt-oss-20b. flash-attention-style implementation does not work on second inference
I copied the flash-attention-style implementation from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
On the first token inference max error is fine = ~1e-2, but on second token it gets huge.
What am I doing wrong here?
Comments:
n_ctx = number of tokens in the current query (1 in our case)
n_kv_ctx = total number of queries including current query (say, 85)
bandwidth = sliding_window (128 or None)
start_q = 85 in our case (84 does not seem to work as well)
Attention forward:
def forward(ctx, q, k, v, sinks, sm_scale, bandwidth, start_q):
bs, n_ctx, n_kv_ctx, n_heads, HEAD_DIM_K, HEAD_DIM_V = q.shape[0], q.shape[2], q.shape[2], 64, 64, 64
BLOCK_M = 64
BLOCK_N = 64
m_pad_size = BLOCK_M - n_ctx % BLOCK_M if n_ctx % BLOCK_M != 0 else 0
# pad q to multiple of its block size in the n_ctx dimension (-2)
q = torch.nn.functional.pad(q, (0, 0, 0, m_pad_size))
n_pad_size = BLOCK_N - n_kv_ctx % BLOCK_N if n_kv_ctx % BLOCK_N != 0 else 0
# pad k and v to multiple of their block size in the n_kv_ctx dimension
k = torch.nn.functional.pad(k, (0, 0, 0, n_pad_size))
v = torch.nn.functional.pad(v, (0, 0, 0, n_pad_size))
o = torch.empty_like(q)
M = torch.empty((bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32)
grid = (triton.cdiv(n_ctx, BLOCK_M), bs * n_heads, 1)
_attn_fwd[grid](
TensorDescriptor.from_tensor(q, [1, 1, BLOCK_M, HEAD_DIM_K]),
TensorDescriptor.from_tensor(k, [1, 1, BLOCK_N, HEAD_DIM_K]),
TensorDescriptor.from_tensor(v, [1, 1, BLOCK_N, HEAD_DIM_K]),
sinks,
sm_scale,
M,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, HEAD_DIM_K]),
start_q,
q.shape[0],
q.shape[1],
N_Q_CTX=n_ctx + m_pad_size,
N_KV_CTX=n_kv_ctx,
HEAD_DIM=HEAD_DIM_K,
BANDWIDTH=bandwidth,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
ctx.save_for_backward(q, k, v, sinks, o, M, start_q)
ctx.sm_scale = sm_scale
ctx.bandwidth = bandwidth
o = o[:, :, :n_ctx, :].transpose(1, 2).contiguous()
#o = o.view(bs, n_ctx, n_heads * HEAD_DIM_V) #will make it in AttentionBlock
return o
===============================================================
eager_attention_forward modification:
def my_eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
sliding_window=None,
s_aux = None,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
offset, n_ctx = min(key.shape[2] - query.shape[2], sliding_window if sliding_window else 0), query.shape[2]
if offset>0: offset = offset - 1
print("offset", query.shape, key.shape, offset, "n_ctx:", n_ctx, "sliding_window:", sliding_window, "scaling:", scaling, kwargs)
start_q = torch.LongTensor([offset]).to(query.device)
t = attention(
query,
key_states,
value_states,
s_aux, #sinks,
scaling,
sliding_window,
start_q
)
#===================================
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
# This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16
# when training with bsz>1 we clamp max values.
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
scores = probs[..., :-1] # we drop the sink here
attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
diff = (attn_output - t).abs().max().item()
print("my_eager_attention_forward:", attn_output.shape, t.shape, attn_output.flatten()[-5:], t.flatten()[-5:], "Error:", diff, "\n")
return attn_output, attn_weights
I would appreciate any help.
Thanks!