Liger-Kernel
fused_linear_cross_entropy: Move float32 cast into kernel
Summary
Another small optimization :) The `logits_chunk.float()` allocation can be surprisingly large: Cohere models, for example, have 256K vocabularies, so a single logit chunk in float32 costs about 1024 * 256K * 4 bytes = 1 GB of VRAM (even more if the chunk size is larger).
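As a quick back-of-the-envelope check of that figure (shapes taken from the example above):

```python
# 1024-row chunk, 256K vocab, 4 bytes per float32 element
chunk_rows, vocab_size, fp32_bytes = 1024, 256 * 1024, 4
print(chunk_rows * vocab_size * fp32_bytes / 2**30)  # -> 1.0 GiB
```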
I don't think any explicit casting is strictly required inside the Triton kernel, since the intermediate softmax variables like `m` and `d` are already float32 by default; with type promotion, the calculations should all be carried out in float32 regardless.
However, I added explicit `.cast(tl.float32)` calls around all of the `X_ptr` loads to make this obvious to the reader. Either way, the actual `liger_cross_entropy_kernel` runs so quickly that I couldn't measure any performance difference - this change is purely about saving the float32 allocation. (It might be slightly more efficient without the explicit casts, but even with a 1K x 256K logit matrix the kernel finishes essentially instantly, so I couldn't tell.)
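To make the pattern concrete, here is a minimal, self-contained sketch of the cast-at-load idea (a toy row-wise logsumexp, not the real `liger_cross_entropy_kernel`; the kernel name and shapes are made up for illustration):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def logsumexp_kernel(X_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # One program per row. The bf16 values are upcast element-wise in
    # registers at load time; no float32 copy of the logits is ever
    # materialized in HBM.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols
    x = tl.load(X_ptr + row * n_cols + offs, mask=mask,
                other=-float("inf")).cast(tl.float32)
    m = tl.max(x, axis=0)              # row max, float32
    d = tl.sum(tl.exp(x - m), axis=0)  # softmax normalizer, float32
    tl.store(out_ptr + row, m + tl.log(d))

x = torch.randn(4, 1000, device="cuda", dtype=torch.bfloat16)
out = torch.empty(4, device="cuda", dtype=torch.float32)
logsumexp_kernel[(4,)](x, out, x.shape[1], BLOCK_SIZE=1024)
```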
Testing Done
- Hardware Type: A100 80GB
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
Thanks! This makes sense. I tried something similar before but saw divergence compared with casting on the torch side (not sure why, maybe I did it wrong). Also, the bfloat16 convergence test is currently not actually exercised due to https://github.com/linkedin/Liger-Kernel/issues/176; after that fix is merged, we can rerun the convergence tests with bf16 to see if there is any gap.
I added a `test_float32_internal()` unit test which runs the kernel twice, once with a bfloat16 input and once with a float32-upcasted version, and verifies that the resulting bfloat16 outputs are exactly identical 🙂 You can also verify that the test passes even without the explicit `.cast(tl.float32)` calls, so those could arguably be removed as long as the test is present.
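For illustration, the core idea of that test can be sketched like this (`softmax_f32` is a hypothetical stand-in for the real kernel; since the bf16 → fp32 upcast is exact, the two paths must agree bitwise):

```python
import torch

# Hypothetical stand-in: upcast internally, return in the input's dtype.
def softmax_f32(x: torch.Tensor) -> torch.Tensor:
    return torch.softmax(x.float(), dim=-1).to(x.dtype)

x_bf16 = torch.randn(8, 1024, dtype=torch.bfloat16)
out_a = softmax_f32(x_bf16)                             # bf16 in, upcast inside
out_b = softmax_f32(x_bf16.float()).to(torch.bfloat16)  # pre-upcast outside
assert torch.equal(out_a, out_b)  # bitwise identical, not just allclose
```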
cool! I will take a deeper look today or tomorrow. This is exciting!
Can we merge this? In its current form, the Liger kernel is broken.
@hansonw can we resolve the conflict? ty
We can merge this once the conflict is resolved. thanks!!
@ByronHsu I rebased this PR and fixed the conflicts, but I'm not sure of the best way to get the changes into this PR. See https://github.com/winglian/Liger-Kernel/pull/new/hw-fused-ce-float
@pramodith would you be interested in moving this forward? Maybe create a new PR based on the current main. We just have to ensure the precision is correct.
@ByronHsu done #406