PyTorch support?

#9
by AliceBeta - opened

I tested flash_attn-2.7.4.post1+cu128torch2.7.0, and it doesn't seem to work:

>>> import torch
>>> import flash_attn
>>> torch.__version__
'2.7.1+cu128'
>>> flash_attn.__version__
'2.7.4.post1'
>>> torch.backends.cuda.is_flash_attention_available()
False

This looks like a PyTorch problem; can sdp_kernel not use flash_attn directly?
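Note that torch.backends.cuda.is_flash_attention_available() only reports whether the PyTorch binary itself was built with its internal flash-attention SDPA kernel; it knows nothing about the separately installed flash_attn wheel. A minimal sketch of forcing that built-in backend through the sdpa_kernel context manager (available since torch 2.3), assuming a CUDA GPU is present; the tensor shapes are arbitrary:

>>> import torch
>>> import torch.nn.functional as F
>>> from torch.nn.attention import sdpa_kernel, SDPBackend
>>> q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
>>> # Restrict SDPA dispatch to the built-in flash-attention kernel only;
>>> # this raises "No available kernel" if the build lacks it for this GPU.
>>> with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
...     out = F.scaled_dot_product_attention(q, k, v)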

>>> from transformers.utils import is_flash_attn_2_available
>>> print(is_flash_attn_2_available())
True
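The two checks answer different questions: transformers' is_flash_attn_2_available() looks for the pip-installed flash_attn package, which transformers calls directly when requested, bypassing torch's SDPA dispatch entirely. A hedged sketch; "some/model" is a placeholder, not a specific checkpoint:

>>> import torch
>>> from transformers import AutoModelForCausalLM
>>> # Uses the flash_attn package directly rather than torch's built-in
>>> # SDPA flash kernel; "some/model" is a placeholder model id.
>>> model = AutoModelForCausalLM.from_pretrained(
...     "some/model",
...     attn_implementation="flash_attention_2",
...     torch_dtype=torch.float16,
...     device_map="cuda",
... )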

Downgrading to PyTorch 2.7.0 also fails with flash_attn on Ampere-architecture GPUs:

RuntimeError: CUDA error: no kernel image is available for execution on the device
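That error usually means neither the flash_attn wheel nor the PyTorch build contains kernels compiled for the GPU's compute capability (sm_80/sm_86 for Ampere). A quick way to compare what the card reports against what the installed PyTorch binary ships; outputs will vary by setup:

>>> import torch
>>> # Compute capability of the current GPU, e.g. (8, 6) on an Ampere card.
>>> torch.cuda.get_device_capability()
>>> # CUDA architectures this PyTorch build includes kernels for; if the
>>> # device's sm_XX is absent here (or from the wheel flash_attn was
>>> # compiled against), you get "no kernel image is available".
>>> torch.cuda.get_arch_list()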

There isn't much relevant discussion of this; the current valid reference appears to be #108175. Although FA2 has supported Windows for some time, PyTorch seems rather passive about becoming compatible with it.
