Ensure the query_states and key_states remain in bf16 (#21)
Browse files
- Ensure the query_states and key_states remain in bf16 (55b8e963ff0aa4a4190cff537165f08c378f62ff)
Co-authored-by: Mohammadreza Mohseni <[email protected]>
- positional_embedding.py +2 -2
positional_embedding.py
CHANGED
|
@@ -269,10 +269,10 @@ class RotaryEmbedding(torch.nn.Module):
|
|
| 269 |
return (
|
| 270 |
apply_rotary_pos_emb(
|
| 271 |
q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
| 272 |
-
),
|
| 273 |
apply_rotary_pos_emb(
|
| 274 |
k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
| 275 |
-
),
|
| 276 |
)
|
| 277 |
|
| 278 |
@classmethod
|
|
|
|
| 269 |
return (
|
| 270 |
apply_rotary_pos_emb(
|
| 271 |
q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
| 272 |
+
).to(q.dtype),
|
| 273 |
apply_rotary_pos_emb(
|
| 274 |
k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
|
| 275 |
+
).to(k.dtype),
|
| 276 |
)
|
| 277 |
|
| 278 |
@classmethod
|