Problem loading the model in bf16
#132
opened by LouisDo2108
Hi @jupyterjazz,

The model weights are in bf16; however, performing encoding in bf16 with adapter_mask results in the following error:

RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.
I've inspected the code for a while, and it appears that several operations in the custom xlm-roberta-flash-implementation's LoRAParametrization, specifically the lora_forward function, cast the weights back to fp32.
I was able to solve this by running the encoding under torch.amp.autocast('cuda', dtype=torch.bfloat16).
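For reference, a minimal sketch of that workaround is below. The checkpoint name, the task argument, and the sample input are assumptions for illustration (the thread does not name the exact model); the relevant part is wrapping the encode call in the bf16 autocast context so the LoRA math stays in bf16.

```python
import torch
from transformers import AutoModel

# Assumed checkpoint for illustration; the thread does not name it.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Autocast keeps intermediate LoRA ops in bf16, avoiding the
# BFloat16/Float dtype mismatch raised by the index_put operation.
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    embeddings = model.encode(
        ["A sample sentence."],
        task="retrieval.query",  # assumed task name; selects the LoRA adapter
    )
```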
LouisDo2108 changed discussion status to closed