With the new multi-backend modular system, how do you intend to support "non-vanilla" models? And will torch.compile be supported?

#9
by Avelina - opened

The new direction of the Transformers library looks like it will be targeting "backend-agnostic" code, which is possible because most Transformer models share the same structure. But how do you plan to support more experimental models? For example, how would new attention architectures such as MLA work? How would you support transformer models which introduce skip connections outside of the residual stream? How would you support transformer models which have non-standard KV caches or apply rotary embeddings in a non-standard way?

I am a researcher who does her best to adhere to the HF spec as closely as possible, but oftentimes there are parts of the spec which simply cannot be satisfied by more exotic models. How do you plan to support such models? And how do you plan to support non-attention-based models such as RNNs/SSMs/RWKV?

In addition to this, how do you plan to support torch.compile? In recent months your torch.compile compatibility has gone from bad to worse, with examples such as Flash Attention 2 no longer being compatible due to data-dependent control flow in 4D attention mask construction -- something which worked last year, before the switch to modular models. Now I'm forced to use SDPA when training HF models with torch.compile for acceleration, as the graph breaks introduced by your new modular FA2 implementation make performance worse than with SDPA.

Hey @Avelina ! If we go back to the example of MLA, support for efficient caching in downstream libraries would simply be a matter of passing the intermediate vectors (that need caching) as additional kwargs to the attention function. Then, the downstream library would need to make sure to use those vectors for caching instead of the traditional key and value states that are passed, and use the module's weights to recreate the past k/v states on-the-fly. This is possible since the attention module itself is an argument (so weights are available), and we let downstream libs handle the caching of past states.
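For illustration, here is a rough sketch of what that could look like on the downstream side. The names (`LatentKVCache`, `latent_kv`, `k_up_proj`, `v_up_proj`, `num_heads`, `head_dim`) are made up for this example and are not the actual transformers API; the point is just that the attention function receives the module plus any extra kwargs, so a library can cache the compressed latents and re-expand them with the module's own weights:

```python
import torch
import torch.nn.functional as F


class LatentKVCache:
    """Hypothetical downstream cache that stores MLA's compressed latent
    vectors instead of the full key/value states."""

    def __init__(self):
        self.latents = {}

    def update(self, layer_idx, latent_kv):
        # latent_kv: (batch, new_tokens, latent_dim) -- append along the sequence axis.
        if layer_idx in self.latents:
            self.latents[layer_idx] = torch.cat([self.latents[layer_idx], latent_kv], dim=1)
        else:
            self.latents[layer_idx] = latent_kv
        return self.latents[layer_idx]


def mla_attention_forward(module, query, key, value, attention_mask,
                          latent_kv=None, past_cache=None, **kwargs):
    # The attention module itself is an argument, so its weights are available
    # to recreate the full key/value states from the cached latents on the fly.
    if past_cache is not None and latent_kv is not None:
        all_latents = past_cache.update(module.layer_idx, latent_kv)
        b, s, _ = all_latents.shape
        # Hypothetical up-projections living on the module.
        key = module.k_up_proj(all_latents).view(b, s, module.num_heads, module.head_dim).transpose(1, 2)
        value = module.v_up_proj(all_latents).view(b, s, module.num_heads, module.head_dim).transpose(1, 2)
    attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    return attn_output, None
```

The key design point is that only the small latent vectors are cached (which is MLA's whole reason for existing), while the standard key/value path is untouched for models that don't need it.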

All other modifications (e.g. skip connections, new rotary embeddings, etc.) are architectural choices that are part of the model definition itself, so they are already agnostic to the backend used (they should work independently of the input format).

For torch.compile and FA2, I'm a bit surprised given that we do not construct 4D masks for FA2 (as it works on unpadded inputs). Especially for training with FA2, if you want to be as efficient as possible, you should avoid passing a mask, and use packed sequences along with the seqlens in order to avoid all padding tokens (see https://huggingface.co/docs/transformers/en/main_classes/data_collator#transformers.DataCollatorWithFlattening). As you don't use a mask anymore, it should stay compile-compatible. But anyway, we're currently working on a massive refactor of how we create the masks (see https://github.com/huggingface/transformers/pull/37866) which will be much more general/efficient.
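For reference, a minimal sketch of that packed-sequence setup (the checkpoint name and `tokenized_dataset` are just placeholders here):

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    DataCollatorWithFlattening,
    Trainer,
    TrainingArguments,
)

# Load with FA2; with flattened (packed) batches no padded attention mask is built.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",           # any causal LM checkpoint works here
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

# Concatenates the examples of each batch into one sequence and emits
# position_ids that restart at every example boundary, so FA2 can recover
# the per-sequence lengths without any padding tokens or attention mask.
collator = DataCollatorWithFlattening()

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=8, bf16=True),
    train_dataset=tokenized_dataset,     # assumed: a pre-tokenized dataset with input_ids
    data_collator=collator,
)
trainer.train()
```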

Hope this answers your questions! 🤗 Let me know if something is unclear/you have any other questions!

Heya @cyrilvallez, thanks for the response! However, there are issues with FA2 which occur regardless of whether a mask is passed or not. When a mask is passed we get one type of graph break; when a mask isn't passed we get a different graph break here: https://github.com/huggingface/transformers/blob/46a4b7c909fab58355936e1c7109bb5e2e558267/src/transformers/modeling_flash_attention_utils.py#L378

About a year ago, when each model's modeling file still implemented its own FA2 forward method, this was not an issue, but now I'm forced to train with SDPA to avoid graph breaks. Graph breaks aren't necessarily the end of the world -- before torch had custom op wrapper support it was normal for FA2 to cause a break in every attention layer -- but these new breaks, being a result of .item() or .all(), are a whole different beast and make FA2 with torch.compile slower than SDPA with torch.compile for both training and batched inference.
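For what it's worth, here is a standalone toy reproduction of this class of break (not the actual transformers code path), using `torch._dynamo.explain` to surface where and why the graph splits:

```python
import torch


# Calling .item() on a tensor inside a compiled region turns its value into
# data-dependent Python control flow, which forces dynamo to break the graph.
def uses_item(x):
    if x.float().mean().item() > 0:  # data-dependent branch -> graph break
        return x * 2
    return x * 3


explanation = torch._dynamo.explain(uses_item)(torch.randn(16))
print(explanation.graph_break_count)
print(explanation.break_reasons)
```

Running with `TORCH_LOGS="graph_breaks"` gives the same information for a full training step, which makes it easy to see how many extra breaks the FA2 path adds per layer.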
