Liger-Kernel
Fix dtype mismatch in fused_linear_cross_entropy_forward
Fixes #305
Fix a dtype mismatch in the `fused_linear_cross_entropy_forward` function.
- Cast `logits_chunk` to the data type of `_input_chunk` before performing operations on it.
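As a rough illustration of why the cast matters, here is a minimal PyTorch sketch, not the actual Liger-Kernel code. It assumes `logits_chunk` arrives in a lower precision (bfloat16 here, as mixed-precision training might produce) while `_input_chunk` is float32. The helper name `grad_weight_from_chunk`, the shapes, and the chosen dtypes are all hypothetical.

```python
import torch


def grad_weight_from_chunk(logits_chunk: torch.Tensor,
                           _input_chunk: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: combine a logits chunk with its input chunk."""
    # The change described above: align dtypes before any further operation.
    logits_chunk = logits_chunk.to(_input_chunk.dtype)
    # (vocab, tokens) @ (tokens, hidden) -> (vocab, hidden)
    return logits_chunk.t() @ _input_chunk


# Assumed setup: float32 inputs, bfloat16 logits.
_input_chunk = torch.randn(4, 8)                      # (tokens, hidden)
logits_chunk = torch.randn(4, 16).to(torch.bfloat16)  # (tokens, vocab)

# Without the cast, the matrix product rejects the mixed dtypes.
try:
    logits_chunk.t() @ _input_chunk
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# With the cast, the product succeeds in the input's dtype.
grad_weight = grad_weight_from_chunk(logits_chunk, _input_chunk)
```

Casting toward the dtype of `_input_chunk` keeps downstream results in the same precision as the inputs, which is the direction the bullet above describes.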
I tested the change in Colab, and it resolved the problem. The run reported the following metrics:
{ "epoch": 1.0, "eval_loss": 1.885668396949768, "eval_runtime": 0.1708, "eval_samples_per_second": 5.856, "eval_steps_per_second": 5.856, "total_flos": 1766475165597696.0, "train_loss": 1.9928909236309575, "train_runtime": 115.5799, "train_samples_per_second": 0.441, "train_steps_per_second": 0.441 }
For more details, open the Copilot Workspace session.