# Advanced Insights: Attention Masks with KV-Caching

## Key Pitfalls in Complex Attention Implementations

### Dimension Evolution with Caching
```python
# Crucial dimension transitions in cached attention:
# [b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
```
The non-obvious trap: even as the K/V cache grows, the attention output keeps the query length `s`. Only the key/value axis grows to `s+cache`.
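The transitions above can be traced with a small NumPy sketch. The weight names (`W_down`, `W_up_k`, `W_up_v`, `W_q`), the sizes, and the random values are assumptions for illustration only, and the causal mask is left out so that only the shapes are in focus.

```python
import numpy as np

rng = np.random.default_rng(0)
b, s, cached, d_model, d_c, num_h = 2, 3, 5, 16, 4, 4
d_head = d_model // num_h

# Hypothetical projection weights, random for illustration only.
W_down = rng.standard_normal((d_model, d_c))   # hidden state -> KV latent
W_up_k = rng.standard_normal((d_c, d_model))   # KV latent -> keys
W_up_v = rng.standard_normal((d_c, d_model))   # KV latent -> values
W_q = rng.standard_normal((d_model, d_model))

x = rng.standard_normal((b, s, d_model))             # new tokens only
c_kv_cache = rng.standard_normal((b, cached, d_c))   # latents from earlier steps

c_kv = np.concatenate([c_kv_cache, x @ W_down], axis=1)  # [b, s+cached, d_c]
k = c_kv @ W_up_k                                        # [b, s+cached, d_model]
v = c_kv @ W_up_v                                        # [b, s+cached, d_model]

def split_heads(t):
    # [b, n, d_model] -> [b, num_h, n, d_head]
    return t.reshape(t.shape[0], t.shape[1], num_h, d_head).transpose(0, 2, 1, 3)

q_h, k_h, v_h = split_heads(x @ W_q), split_heads(k), split_heads(v)

# Shape-only sketch: no causal mask is applied here.
scores = q_h @ k_h.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # [b, num_h, s, s+cached]
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ v_h                                         # [b, num_h, s, d_head]

print(c_kv.shape, k.shape, scores.shape, out.shape)
```

The score matrix is rectangular (`s` rows, `s+cached` columns), while the output has `s` rows regardless of how long the cache is.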

### Mask Causality with Growing Cache
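Standard causal masks break with KV-caching: a square `[s, s]` lower-triangular mask assumes queries and keys start at the same position, which no longer holds once the keys include cached tokens. Critical edge cases:
- The token at local index `i` (absolute position `start_pos + i`) must attend to key positions `0` through `start_pos + i` inclusive, i.e. the slice `[0 : start_pos + i + 1]`
- Naively extending the mask (for example, padding the square mask out to the cached length) misaligns the diagonal and breaks causality
- Generating a fresh mask at every position adds per-step overhead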
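As a minimal sketch of the offset rule, the mask can be built from absolute positions rather than from a square triangle. The function names and the `start_pos` argument below are illustrative assumptions, not taken from a specific library.

```python
import numpy as np

def cached_causal_mask(s, start_pos):
    """Boolean mask [s, start_pos + s]; True means query i may attend to key j."""
    q_pos = start_pos + np.arange(s)[:, None]   # absolute position of each query
    k_pos = np.arange(start_pos + s)[None, :]   # absolute position of each key
    return k_pos <= q_pos

def naive_extended_mask(s, start_pos):
    """Incorrect on purpose: square causal mask padded on the right, ignoring the offset."""
    square = np.tril(np.ones((s, s), dtype=bool))
    pad = np.zeros((s, start_pos), dtype=bool)
    return np.concatenate([square, pad], axis=1)

mask = cached_causal_mask(s=2, start_pos=3)
naive = naive_extended_mask(s=2, start_pos=3)
print(mask.astype(int))
# [[1 1 1 1 0]
#  [1 1 1 1 1]]
print(naive.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]]
```

With three cached tokens, the first new query sees four keys (the cache plus itself). The naive version hides the whole cache from it.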

### Optimization Considerations
1. Memory vs Compute tradeoff: Precomputing extended masks vs generating per position
2. Batch dimension handling: Mask broadcasting impacts memory usage
3. Fused attention patterns may break with custom mask handling
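The first two tradeoffs can be illustrated in NumPy. This is a sketch under assumed sizes (`max_len`, `b`, `num_h`) and hypothetical helper names. It compares slicing a precomputed mask against building one per call, and a broadcast view against a materialized per-batch, per-head copy.

```python
import numpy as np

max_len, b, num_h = 1024, 8, 16

# Option A: pay O(max_len^2) memory once, then slice rows and columns per step.
full_mask = np.tril(np.ones((max_len, max_len), dtype=bool))

def mask_from_precomputed(s, start_pos):
    return full_mask[start_pos:start_pos + s, :start_pos + s]

# Option B: no stored mask, but the comparison is recomputed on every call.
def mask_on_the_fly(s, start_pos):
    return np.arange(start_pos + s)[None, :] <= (start_pos + np.arange(s))[:, None]

m = mask_from_precomputed(4, 100)

# Broadcasting over batch and heads creates a zero-stride view, not a copy.
view = np.broadcast_to(m, (b, num_h) + m.shape)
# Forcing a contiguous array materializes b * num_h copies of the mask.
copy = np.ascontiguousarray(view)

print(view.strides[:2], np.shares_memory(view, full_mask))
print(m.nbytes, copy.nbytes)
```

Whether a given attention implementation keeps the mask as a broadcast view or expands it internally depends on that implementation, so this only shows where the memory cost would come from if it does expand.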

## Debugging Strategy for Non-Obvious Cases
Monitor these dimension transitions for subtle bugs:
```python
C_KV.shape      # Should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
K_state.shape   # Post-projection growth affects attention patterns
att_output.shape # Must maintain query dimensions despite K/V growth
```
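These checks can be turned into explicit assertions. The helper below is hypothetical and assumes the layouts `C_KV: [b, total, d_c]`, `K_state: [b, total, d_model]`, and `att_output: [b, num_h, s_new, d_head]`. Adjust the indexed axes if your tensors are laid out differently.

```python
import numpy as np

def check_cached_step(prev_len, s_new, C_KV, K_state, att_output):
    """Hypothetical helper: assert the shapes expected after one cached step."""
    total = prev_len + s_new
    # The latent cache and the projected keys must both cover old + new tokens.
    assert C_KV.shape[1] == total, f"C_KV length {C_KV.shape[1]} != {total}"
    assert K_state.shape[1] == total, f"K_state length {K_state.shape[1]} != {total}"
    # The output must follow the query length, not the cache length.
    assert att_output.shape[-2] == s_new, (
        f"att_output length {att_output.shape[-2]} != query length {s_new}"
    )

b, num_h, d_head, d_c = 2, 4, 8, 16
prev_len, s_new = 5, 3
check_cached_step(
    prev_len, s_new,
    C_KV=np.zeros((b, prev_len + s_new, d_c)),
    K_state=np.zeros((b, prev_len + s_new, num_h * d_head)),
    att_output=np.zeros((b, num_h, s_new, d_head)),
)
```

Calling this after every decoding step makes a cache that silently stops growing, or an output that picks up the key length, fail at the step where it first happens.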

## Practical Example: DeepSeek's MLA Edge Case
In Multi-head Latent Attention (MLA), the compressed KV cache introduces subtle interactions with attention masks due to:
1. Joint compression affecting position-dependent patterns
2. Non-standard dimension flow through compression/decompression
3. Mask causality preservation across cached compressed states
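One way to test the third point is to compare chunked decoding over a latent cache against a single full-sequence pass. The sketch below is a deliberately simplified stand-in, not DeepSeek's implementation: it assumes a single head, no positional encoding, and random hypothetical weights (`W_q`, `W_down`, `W_up_k`, `W_up_v`).

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_c = 6, 8, 3
# Hypothetical weights, random for illustration only.
W_q = rng.standard_normal((d_model, d_model))
W_down = rng.standard_normal((d_model, d_c))
W_up_k = rng.standard_normal((d_c, d_model))
W_up_v = rng.standard_normal((d_c, d_model))

def attend(x_new, c_kv, start_pos):
    """Single-head attention of the new tokens over the whole latent cache."""
    s, total = x_new.shape[0], c_kv.shape[0]
    q, k, v = x_new @ W_q, c_kv @ W_up_k, c_kv @ W_up_v
    scores = q @ k.T / np.sqrt(d_model)                      # [s, total]
    # Causality is decided by absolute positions, not by latent contents.
    allowed = np.arange(total)[None, :] <= (start_pos + np.arange(s))[:, None]
    scores = np.where(allowed, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v           # [s, d_model]

x = rng.standard_normal((n, d_model))

# Reference: the whole sequence in one pass.
full = attend(x, x @ W_down, start_pos=0)

# Incremental: chunks of two tokens; the cache stores only latents.
cache = np.zeros((0, d_c))
chunks = []
for start in range(0, n, 2):
    x_new = x[start:start + 2]
    cache = np.concatenate([cache, x_new @ W_down], axis=0)
    chunks.append(attend(x_new, cache, start_pos=start))
incremental = np.concatenate(chunks, axis=0)

print(np.allclose(full, incremental))
```

If the two results diverge in a real implementation, the mask offset and the order of latents in the cache are the first things to inspect.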