Reasons for upcasting the logits dtype outside the kernel

Open · yzhangcs opened this issue 1 year ago · 7 comments

Hello, thank you for this great work. https://github.com/linkedin/Liger-Kernel/blob/acd82728207ebafad28d448640502c108901a967/src/liger_kernel/ops/fused_linear_cross_entropy.py#L69

https://github.com/linkedin/Liger-Kernel/blob/acd82728207ebafad28d448640502c108901a967/src/liger_kernel/ops/fused_linear_cross_entropy.py#L91-L96

Is there a reason for upcasting/downcasting the logits dtype outside the kernel? If I understand correctly, the kernel already upcasts to fp32 internally, so this op seems redundant. I compared the outputs of the two versions, i.e., with and without the upcast, and found no precision loss when the code above is removed.
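To make the argument concrete, here is a minimal pure-Python sketch (no PyTorch or Triton needed) that emulates bf16 storage with `struct`. `kernel_cross_entropy` is a hypothetical stand-in, not Liger-Kernel's actual kernel. It assumes, as the question does, that the kernel upcasts each loaded logit to fp32 before doing any math. Every bf16 value is exactly representable in fp32, so an extra upcast outside the kernel hands it bit-identical fp32 inputs and the loss cannot change.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest IEEE-754 float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite value to bfloat16 (round-to-nearest-even on the fp32 bits).
    NaN/Inf handling is omitted for brevity."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so the dropped low 16 bits round to nearest, ties to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def kernel_cross_entropy(stored_logits, target):
    """Hypothetical CE "kernel": upcast each loaded logit to fp32, then compute
    a numerically stable log-sum-exp loss with every op rounded to fp32."""
    x = [to_fp32(v) for v in stored_logits]  # the in-kernel fp32 upcast
    m = max(x)
    s = 0.0
    for v in x:
        s = to_fp32(s + to_fp32(math.exp(to_fp32(v - m))))
    lse = to_fp32(m + to_fp32(math.log(s)))
    return to_fp32(lse - x[target])


# Logits as they would be stored in a bf16 buffer.
logits_bf16 = [to_bf16(v) for v in [2.3, -1.7, 0.41, 5.9, -3.2]]
target = 3

# Path A: pass bf16 logits straight to the kernel (it upcasts internally).
loss_inside = kernel_cross_entropy(logits_bf16, target)

# Path B: upcast to fp32 outside the kernel first (the op in question).
logits_fp32 = [to_fp32(v) for v in logits_bf16]
loss_outside = kernel_cross_entropy(logits_fp32, target)

print(loss_inside, loss_outside, loss_inside == loss_outside)
```

The sketch only models the forward loss value. It does not model the dtype in which a real kernel writes gradients back into the logits buffer.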

yzhangcs · Sep 10 '24 17:09