Fix BF16 training
#19 opened by alexanderchemeris
For long sequences, this calculation yields an incorrect result because BF16 has so few mantissa bits. E.g., for 640 elements, it produces valid_lengths = 639. Converting to long early solves the issue.
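To illustrate why counting in BF16 is fragile, here is a minimal sketch in plain Python. It is not the repository's code: `to_bf16`, `mask`, `bf16_count`, and `int_count` are hypothetical names, and the rounding is emulated with `struct` instead of using real tensors. It also does not reproduce the exact 639 reported above, since that depends on the actual expression and kernel, which are not shown in this thread. It only shows that BF16 cannot represent every integer above 256, while an integer count stays exact.

```python
import struct


def to_bf16(value: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # bfloat16 keeps the top 16 bits of a float32; round away the 16 dropped bits.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Hypothetical mask of 640 valid positions, stored as floating-point ones.
mask = [1.0] * 640

# Naive accumulation where every partial sum is rounded to bfloat16.
bf16_count = 0.0
for m in mask:
    bf16_count = to_bf16(bf16_count + m)

# Casting each element to an integer first keeps the count exact.
int_count = sum(int(m) for m in mask)

print(to_bf16(257.0))  # 256.0: 257 is not representable in bfloat16
print(to_bf16(639.0))  # 640.0: integers between 512 and 1024 are spaced 4 apart
print(bf16_count)      # 256.0: the running sum stalls once 256 + 1 rounds back to 256
print(int_count)       # 640
```

Real reduction kernels often accumulate in higher precision, so the error seen in practice can be much smaller than in this worst-case loop, but the integer cast removes the dependence on floating-point rounding altogether.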
Thank you for pointing out this issue. But I think this issue has already been fixed by line 133:
mask = x[:, :, -1].long()
Do you think there are still issues with the current code?
Sorry, I didn't notice this recent fix. I think it's equivalent. I'll check and come back.
xiezhe24 changed pull request status to closed