getting Inference Error

#18
by Gyaneshere - opened

I get the following error when trying to use Dia-1.6B:

gradio.exceptions.Error: 'Inference failed: Expected query, key, and value to have the same dtype, but got query.dtype: c10::BFloat16 key.dtype: float and value.dtype: float instead.'


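The problem is in the forward function in /dia/layers.py. As the error message says, the query tensor is bfloat16 while the key and value tensors are float32, and scaled_dot_product_attention requires all three to have the same dtype.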

Before this line:
    attn_output = F.scaled_dot_product_attention(

Just add:
    if attn_k is not None and attn_v is not None:
        attn_k = attn_k.to(Xq_BxNxTxH.dtype)
        attn_v = attn_v.to(Xq_BxNxTxH.dtype)
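
For anyone who wants to see the idea outside of Dia, here is a minimal standalone sketch of the same cast. It is not the code from layers.py: the helper name `sdpa_with_matching_dtypes` and the tensor shapes are made up for illustration, and it skips the `None` guard that the patch above keeps.

```python
import torch
import torch.nn.functional as F


def sdpa_with_matching_dtypes(q, k, v):
    """Cast key and value to the query's dtype, then run attention.

    Hypothetical helper for illustration only; Dia's layers.py applies
    the same cast inline, just before its own attention call.
    """
    k = k.to(q.dtype)
    v = v.to(q.dtype)
    return F.scaled_dot_product_attention(q, k, v)


# Illustrative shapes: batch=1, heads=2, sequence=4, head_dim=8.
q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 4, 8, dtype=torch.float32)
v = torch.randn(1, 2, 4, 8, dtype=torch.float32)

# Mixed dtypes are expected to be rejected, as in the error in this thread.
try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as err:
    print("mismatch:", err)

# After the cast, all three tensors share the query's dtype.
out = sdpa_with_matching_dtypes(q, k, v)
print(out.dtype, tuple(out.shape))
```

Casting the key and value down to the query's dtype (rather than casting the query up to float32) keeps the attention computation in the reduced precision the model was loaded with, which is what the patch above does.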

A new version of layers.py was pushed to their GitHub a few hours ago, so you may be able to fix this just by updating to the latest version:
https://github.com/nari-labs/dia/tree/main/dia

For more information, please visit this web page:
https://github.com/devnen/Dia-TTS-Server/issues/1
