
fused_linear_cross_entropy: Move float32 cast into kernel

Open · hansonw opened this issue 1 year ago · 6 comments

Summary

Another small optimization :) The logits_chunk.float() allocation can be surprisingly large: Cohere models, for example, have a 256K vocabulary, so a single logit chunk in float32 costs roughly 1024 * 256K * 4 bytes = 1 GB of VRAM (even more if the chunk size is larger).
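
For scale, a quick back-of-the-envelope check of that allocation (the chunk shape is just the example from above):

```python
# Size of one float32 logit chunk for a ~256K-vocab model (e.g. Cohere).
chunk_rows = 1024        # assumed chunk size
vocab_size = 256 * 1024  # 262,144 vocabulary entries
fp32_bytes = 4

print(chunk_rows * vocab_size * fp32_bytes / 2**30, "GiB")  # -> 1.0 GiB
```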

I actually don't think any explicit casting is required inside the Triton kernel, since the intermediate softmax variables like m, d, etc. are already float32 by default; with type promotion, the calculations should all happen in float32 regardless.

However, I added explicit .cast(tl.float32) calls around all of the X_ptr loads to make this obvious to the reader. Either way, the actual liger_cross_entropy_kernel runs so quickly that I don't expect any measurable performance difference; the change is purely to save the float32 allocation. (It might be slightly more efficient without the explicit casts, but I was not able to measure a difference: even with a 1K x 256K logit matrix, the kernel runs essentially instantly.)
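
For illustration, here is a minimal sketch of the pattern (a simplified stand-in, not the actual liger_cross_entropy_kernel; the kernel name, the one-block-per-row layout, and the launch parameters are all assumptions): the logits stay in bf16 in global memory and each load is upcast in registers, so no float32 copy of the chunk is ever materialized.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def ce_fp32_internal_kernel(
    X_ptr,      # (n_rows, n_cols) logits in their storage dtype, e.g. bf16
    Y_ptr,      # (n_rows,) int64 target indices
    loss_ptr,   # (n_rows,) fp32 per-row loss
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    # Explicit per-element upcast on load; the fp32 reductions below would
    # promote the values anyway, but the cast makes the precision obvious.
    x = tl.load(X_ptr + row * n_cols + cols, mask=mask,
                other=-float("inf")).cast(tl.float32)

    m = tl.max(x, axis=0)              # fp32 row max
    d = tl.sum(tl.exp(x - m), axis=0)  # fp32 softmax denominator
    y = tl.load(Y_ptr + row)
    x_y = tl.load(X_ptr + row * n_cols + y).cast(tl.float32)
    tl.store(loss_ptr + row, m + tl.log(d) - x_y)  # -log_softmax at target
```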

Testing Done

  • Hardware Type: A100 80GB
  • [x] run make test to ensure correctness
  • [x] run make checkstyle to ensure code style
  • [x] run make test-convergence to ensure convergence

hansonw · Sep 09 '24

Thanks! This makes sense. I tried something similar before but saw divergence compared with casting on the torch side (not sure why, maybe I did it wrong). Also, the bfloat16 convergence test is not actually being run due to https://github.com/linkedin/Liger-Kernel/issues/176. Once that fix is merged, we can run the convergence tests with bf16 to see if there is any gap.

ByronHsu · Sep 10 '24

I added a test_float32_internal() unit test which runs the kernel twice, once with a bfloat16 input and once with a float32-upcast version, and verifies that the resulting outputs (in bfloat16) are exactly identical 🙂 You can also verify that the test passes even without the explicit .cast(tl.float32) calls, so those could perhaps be removed as long as the test is present. A sketch of the idea is below.
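
Conceptually, the test does something like this (a sketch reusing the hypothetical ce_fp32_internal_kernel from the Summary above, not the real test or kernel):

```python
import torch

def launch(logits: torch.Tensor) -> torch.Tensor:
    # Thin wrapper around the sketch kernel; targets are fixed for simplicity.
    n_rows, n_cols = logits.shape
    loss = torch.empty(n_rows, device=logits.device, dtype=torch.float32)
    targets = torch.zeros(n_rows, device=logits.device, dtype=torch.int64)
    ce_fp32_internal_kernel[(n_rows,)](logits, targets, loss, n_cols,
                                       BLOCK_SIZE=1024)
    return loss

def test_float32_internal():
    x_bf16 = torch.randn(32, 1000, device="cuda", dtype=torch.bfloat16)
    loss_internal = launch(x_bf16)          # upcast happens inside the kernel
    loss_external = launch(x_bf16.float())  # upcast done on the torch side
    # Exact equality, not allclose: both paths should do identical fp32 math,
    # since the bf16 -> fp32 conversion is lossless.
    assert torch.equal(loss_internal.to(torch.bfloat16),
                       loss_external.to(torch.bfloat16))
```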

hansonw · Sep 10 '24

Cool! I will take a deeper look today or tomorrow. This is exciting!

ByronHsu · Sep 10 '24

Can we merge this? In its current form the Liger kernel is broken.

kostum123 · Sep 30 '24

@hansonw can we resolve the conflict? ty

lancerts · Oct 01 '24

We can merge this once the conflict is resolved. Thanks!!

ByronHsu · Oct 02 '24

@ByronHsu I rebased this PR and fixed the conflicts, but I'm not sure of the best way to get the changes into this PR; see https://github.com/winglian/Liger-Kernel/pull/new/hw-fused-ce-float

winglian · Nov 21 '24

@pramodith would you be interested in moving this forward? Maybe create a new PR based on the current main. We just have to ensure the precision is correct.

ByronHsu · Nov 21 '24

> @pramodith would you be interested in moving this forward? Maybe create a new PR based on the current main. We just have to ensure the precision is correct.

@ByronHsu done #406

pramodith · Nov 22 '24