Getting an inference error
I get this error when trying to use Dia-1.6B:
gradio.exceptions.Error: 'Inference failed: Expected query, key, and value to have the same dtype, but got query.dtype: c10::BFloat16 key.dtype: float and value.dtype: float instead.'
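The problem is in the forward function in /dia/layers.py: the query tensor is bfloat16 while the key and value tensors are still float32, so they have to be cast to the query's dtype before the attention call.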
Before this line:
attn_output = F.scaled_dot_product_attention(
Just add:
if attn_k is not None and attn_v is not None:
    attn_k = attn_k.to(Xq_BxNxTxH.dtype)
    attn_v = attn_v.to(Xq_BxNxTxH.dtype)
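If you want to see the idea outside of Dia, here is a minimal standalone sketch. It is not the actual code from layers.py: the helper name match_kv_dtype and the tensor shapes are made up for illustration, and it only assumes a recent PyTorch 2.x with F.scaled_dot_product_attention available.

```python
import torch
import torch.nn.functional as F


def match_kv_dtype(q, k, v):
    """Cast key and value to the query's dtype (hypothetical helper, not part of Dia)."""
    if k is not None and v is not None:
        k = k.to(q.dtype)
        v = v.to(q.dtype)
    return k, v


# Toy shapes (arbitrary): batch=1, heads=2, seq_len=4, head_dim=8.
q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)  # query in bfloat16
k = torch.randn(1, 2, 4, 8, dtype=torch.float32)   # key left in float32
v = torch.randn(1, 2, 4, 8, dtype=torch.float32)   # value left in float32

# Without this cast, the attention call below is given mixed dtypes.
k, v = match_kv_dtype(q, k, v)

out = F.scaled_dot_product_attention(q, k, v)
print(out.dtype, tuple(out.shape))
```

Casting the key and value to the query's dtype (rather than the other way round) keeps the attention computation in the precision the model was loaded with.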
There is a new version of layers.py on their GitHub that is only a few hours old, so you may be able to fix this just by updating to the latest version:
https://github.com/nari-labs/dia/tree/main/dia
For more information, see this issue:
https://github.com/devnen/Dia-TTS-Server/issues/1