Liger-Kernel
Reasons for upcasting the logits dtype outside the kernel
Hello, thank you for this great work. https://github.com/linkedin/Liger-Kernel/blob/acd82728207ebafad28d448640502c108901a967/src/liger_kernel/ops/fused_linear_cross_entropy.py#L69
https://github.com/linkedin/Liger-Kernel/blob/acd82728207ebafad28d448640502c108901a967/src/liger_kernel/ops/fused_linear_cross_entropy.py#L91-L96
I'm wondering whether there is a reason for upcasting and downcasting the logits dtype outside the kernel. If I understand correctly, the kernel already upcasts to fp32 internally, so is this op redundant? I compared the outputs of the two versions, i.e., w/ and w/o the outside upcast, and found no precision loss when the code linked above is removed.
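To make the comparison concrete, here is a minimal pure-Python sketch of the two variants. It is not the Liger-Kernel code: `to_bf16`, `softmax_grad`, `with_outside_cast`, and `without_outside_cast` are hypothetical helpers, bf16 is emulated by rounding the fp32 bit pattern, and Python floats stand in for the kernel's internal fp32 math. It also assumes the kernel writes each output element once.

```python
import math
import struct


def to_bf16(x):
    """Emulate bfloat16 storage: round an fp32 bit pattern to its top 16 bits
    (round-to-nearest-even). Hypothetical helper; NaN/inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softmax_grad(logits, target):
    """Cross-entropy gradient w.r.t. logits: softmax(logits) - one_hot(target).
    Python floats (fp64) stand in for the kernel's internal fp32 math."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total - (1.0 if i == target else 0.0) for i, e in enumerate(exps)]


def with_outside_cast(logits_bf16, target):
    # Upcast outside: widening a bf16 value is exact, so no value changes.
    wide = [float(v) for v in logits_bf16]
    grad = softmax_grad(wide, target)
    # Downcast outside: round the wide result back to the storage dtype.
    return [to_bf16(g) for g in grad]


def without_outside_cast(logits_bf16, target):
    # The kernel loads bf16, widens internally, and rounds once on store.
    grad = softmax_grad(logits_bf16, target)
    return [to_bf16(g) for g in grad]


# Logits as they would come out of a bf16 matmul (already bf16-representable).
logits = [to_bf16(v) for v in (2.5, -1.25, 0.3, 4.0, 0.7)]
grads_outside = with_outside_cast(logits, target=3)
grads_inside = without_outside_cast(logits, target=3)

print(grads_outside == grads_inside)

# What does differ is the size of the intermediate logits buffer.
n = len(logits)
bytes_fp32_buffer = 4 * n  # logits held in fp32 between matmul and kernel
bytes_bf16_buffer = 2 * n  # logits kept in bf16
print(bytes_fp32_buffer, bytes_bf16_buffer)
```

Under these assumptions the two paths agree bit for bit, which matches what I observed. The results could differ if the kernel stored intermediate values back into the logits buffer and re-read them, since each store into a bf16 buffer would round.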