FlashAttention-2 support?

#14
by t-albertge - opened

Hi there,

would it be possible to add FlashAttention-2 support to the model? I think the modeling code already uses torch's SDPA kernel in the forward call of LLaDABlock, but would it be possible to use FlashAttention-2 as well? Thanks!
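For context, here is a minimal sketch of the distinction being asked about (assumed tensor shapes, not the actual LLaDA code): PyTorch's `F.scaled_dot_product_attention` can already dispatch to a fused FlashAttention kernel on supported GPUs, while FlashAttention-2 support in `transformers` usually means loading with `attn_implementation="flash_attention_2"`, which routes attention through the separate `flash-attn` package instead.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# This is the SDPA call the modeling code reportedly already uses; on a
# supported CUDA GPU PyTorch may pick a flash-style fused kernel for it
# automatically, whereas on CPU it falls back to the math backend.
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
print(tuple(out.shape))  # → (1, 8, 128, 64)
```

Explicit FlashAttention-2 support would typically look like `AutoModel.from_pretrained(..., attn_implementation="flash_attention_2")` in `transformers`, which requires the model's attention classes to opt in.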
