Problem loading the model in bf16

#132
by LouisDo2108 - opened

Hi @jupyterjazz ,

The model weights are in bf16; however, encoding in bf16 with an adapter_mask results in the following error:

RuntimeError: Index put requires the source and destination dtypes match, got BFloat16 for the destination and Float for the source.

I've inspected the code for a while, and it appears that several operations in the custom xlm-roberta-flash-implementation's LoRAParametrization.lora_forward function cast the weights back to fp32.

I have been able to solve this by wrapping the encoding call in torch.amp.autocast('cuda', dtype=torch.bfloat16).
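For reference, a minimal sketch of the workaround (the model id and the encode() arguments are illustrative assumptions; adjust them to your setup):

```python
import torch
from transformers import AutoModel

# Assumed model id and loading setup; the key point is loading in bf16 with trust_remote_code.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v3",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")

sentences = ["A sample sentence to embed."]

# Wrapping the call in autocast keeps the LoRA computations in bf16,
# avoiding the BFloat16/Float mismatch raised by the index_put operation.
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    embeddings = model.encode(sentences, task="retrieval.query")  # task name is illustrative
```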

LouisDo2108 changed discussion status to closed
