RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800
#18
by g-ronimo - opened
Thank you for this model!
Any idea how to resolve this when finetuning (QLoRA) Gemma-7B with FA2 on a 3090?
/home/g/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
warnings.warn(
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Traceback (most recent call last):
  File "/home/g/gemma-ft/qlora-OA.py", line 262, in <module>
    trainer.train()
  File "/home/g/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/home/g/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1961, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/g/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2911, in training_step
    self.accelerator.backward(loss)
  File "/home/g/accelerate_fork/src/accelerate/accelerator.py", line 1966, in backward
    loss.backward(**kwargs)
  File "/home/g/.local/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/g/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
  File "/home/g/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/g/.local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 319, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/home/g/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
  File "/home/g/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/g/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 531, in backward
    _flash_attn_backward(
  File "/home/g/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 131, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800
Same error here. It works on my instance for Mistral but not Gemma.
It's just that the head dim of this model is bigger 😭 (Gemma-7B uses head_dim=256, above the 192 limit in the error), so another kernel is required, it seems.
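For anyone hitting the same thing on a consumer GPU, a possible workaround is to skip FlashAttention-2 entirely and let the model run with PyTorch SDPA (or eager) attention, since only the FA2 backward kernel for head dim > 192 is restricted to A100/A800 or H100/H800. This is a minimal sketch, not a confirmed fix from this thread; it assumes a recent transformers with Gemma support and the google/gemma-7b checkpoint:

```python
# Sketch of a possible workaround: load Gemma-7B for QLoRA without FA2.
# Assumptions (not confirmed in this thread): transformers >= 4.38, model id "google/gemma-7b".
import torch
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/gemma-7b"  # assumed model id

# Check the head dimension that trips the FA2 backward kernel on a 3090.
config = AutoConfig.from_pretrained(model_id)
print("head_dim:", getattr(config, "head_dim", None))  # Gemma-7B reports 256

# 4-bit quantization config for QLoRA, as in a typical setup.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Use "sdpa" (or "eager") instead of "flash_attention_2" so the backward
# pass never calls the unsupported FlashAttention kernel.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
```

Training is slower and uses more memory than FA2, but it avoids the RuntimeError above.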
g-ronimo changed discussion status to closed