Triton nanoGPT now has a custom cross-entropy loss kernel
Next: matmul, gradually overthrowing all major PyTorch ops:)
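Simplified pseudocode for the parallel cross-entropy loss computation (each program handles one row of logits):
- init program: get pid, compute offsets, load the target for this row.
- init row_max (to -inf) and row_sum (to 0).
- for-loop 1 (find max logit): update row_max with the max of each block of logits.
- for-loop 2 (compute softmax denominator and loss): add exp(logits - row_max) to row_sum, subtract (target logit - row_max) from the loss.
- add log(row_sum) and store the loss.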
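
For anyone who wants to check the math without a GPU, here is a rough CPU sketch of the same steps in plain Python. It is not the Triton kernel from the notebook: `cross_entropy_row`, `cross_entropy_rows`, and `block_size` are made-up names for illustration, and each call to `cross_entropy_row` stands in for what one program (pid) would do.

```python
import math

def cross_entropy_row(logits, target, block_size=4):
    """Loss for a single row, processed in blocks like one kernel program would."""
    n = len(logits)
    row_max = float("-inf")  # running maximum of the row
    row_sum = 0.0            # running sum of exp(logit - row_max)

    # Loop 1: walk the row block by block to find the maximum logit.
    for start in range(0, n, block_size):
        block = logits[start:start + block_size]
        row_max = max(row_max, max(block))

    # Loop 2: accumulate the softmax denominator and pick up the target term.
    loss = 0.0
    for start in range(0, n, block_size):
        block = logits[start:start + block_size]
        row_sum += sum(math.exp(x - row_max) for x in block)
        if start <= target < start + len(block):
            loss -= logits[target] - row_max

    # Final step: loss = log(row_sum) - (logit[target] - row_max)
    return loss + math.log(row_sum)

def cross_entropy_rows(logits_2d, targets, block_size=4):
    """Per-row losses; on a GPU each row would run as its own program."""
    return [cross_entropy_row(row, t, block_size)
            for row, t in zip(logits_2d, targets)]
```

Subtracting row_max before exp is what keeps this stable: a row like [1000.0, 1000.0, 1000.0] would overflow with a naive exp, but here it just gives log(3).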
Code: https://github.com/Jaykef/ai-algorithms/blob/main/triton_nanoGPT.ipynb