Does transformers utilize PyTorch SDPA's flash_attention for openai/gpt-oss-20b?
I'm investigating if the flash_attention backend from PyTorch's scaled_dot_product_attention (SDPA) is leveraged when running the openai/gpt-oss-20b model via the transformers library. How can we verify this behavior? I'm looking for methods, code snippets, or official documentation that confirm whether this optimization is active by default or if specific configurations are required to enable it.
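One way to verify, as a rough sketch: after loading, transformers records the attention backend it resolved on the model config. The _attn_implementation attribute is an internal detail rather than an official API, so treat this as a heuristic check rather than a guaranteed method:

import torch
from transformers import AutoModelForCausalLM

# Rough sketch: load the model without forcing a backend, then inspect which
# attention implementation transformers actually resolved for it.
# "_attn_implementation" is an internal config attribute, so this is a
# heuristic check, not an official API.
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(model.config._attn_implementation)  # e.g. "eager", "sdpa", or "flash_attention_2"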
You can switch between the various attn_implementation options as described here: https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.attn_implementation
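For example, a minimal sketch (model id assumed from the question) of requesting a specific backend explicitly at load time:

from transformers import AutoModelForCausalLM

# Minimal sketch: explicitly request the SDPA backend; transformers raises a
# ValueError at load time if the architecture doesn't support it.
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    attn_implementation="sdpa",
)

If the architecture doesn't support the requested backend, from_pretrained raises a ValueError, which is exactly what the next reply hit.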
It looks like the gpt-oss model doesn't support SDPA; I got:
raise ValueError(
ValueError: GptOssForCausalLM does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument attn_implementation="eager" meanwhile. Example: model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")
Edit / Correction:
I am able to load the model with FA2, and it does reduce memory usage on OSS:
import torch
from transformers import AutoModelForCausalLM

model_slug = "openai/gpt-oss-20b"
model = AutoModelForCausalLM.from_pretrained(
    model_slug,
    torch_dtype=torch.bfloat16,  # bf16 on H100/H200
    device_map="auto",
    attn_implementation="flash_attention_2",
)
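To get a rough sense of the memory savings (a sketch only, following on from the snippet above; the prompt is a placeholder and peak memory depends heavily on sequence length):

import torch
from transformers import AutoTokenizer

# Rough memory check after a single forward pass; numbers will vary by
# hardware, batch size, and sequence length.
tokenizer = AutoTokenizer.from_pretrained(model_slug)
inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model(**inputs)
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")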
I upgraded to the latest transformers. Note, FYI, that this does not (yet?) appear to work with unsloth, for reasons I don't understand.
Original post:
Yeah @reach-vb, this is a pretty big bottleneck, meaning you can't really train OSS on reasoning.
SDPA isn't supported for OSS (you can see the fallback to eager logged explicitly when you load with unsloth).
It would be nice if setting attn_implementation to FA2 worked, but I have tried that and it also just causes a fallback to eager.
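As a side check on the original question, PyTorch's flash SDPA backend can also be probed directly, independent of transformers. A minimal sketch, assuming PyTorch >= 2.3 (for torch.nn.attention) and a supported GPU:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the flash backend for a toy call; if flash can't dispatch
# (unsupported GPU, dtype, or head size), this raises a RuntimeError instead
# of silently falling back, which tells you whether flash is available at all.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)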