gpt-neox
Changing distillation weights changes runtime
Describe the bug Imbalances in the distillation weights appear to have a significant impact on runtime: with all three weights set to 1, training runs twice as fast as with lm_loss set to 1 and the other two weights set to 0.1.
To Reproduce
- Run a med-to-small distillation with weights 1, 1, 1
- Run a med-to-small distillation with weights 1, 0.5, 0.5
- Run a med-to-small distillation with weights 1, 0.1, 0.1
Expected behavior I did not expect the weight values to cause such a large variation in runtime.
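For context, distillation losses are typically combined as a simple weighted sum, so changing a weight should only rescale a term, not change how much is computed. A minimal sketch of that expectation (the function and parameter names here are illustrative, not the actual gpt-neox config keys):

```python
# Illustrative sketch of a weighted distillation loss. The names
# (lm_loss_weight, kld_loss_weight, mse_loss_weight) are assumptions
# for illustration, not the actual gpt-neox config keys.
def combined_loss(lm_loss, kld_loss, mse_loss,
                  lm_loss_weight=1.0, kld_loss_weight=1.0,
                  mse_loss_weight=1.0):
    """Weighted sum of the three distillation loss terms.

    Mathematically, scaling a weight only rescales that term's
    contribution (and its gradient); every term is still computed,
    so runtime should not depend on the weight values.
    """
    return (lm_loss_weight * lm_loss
            + kld_loss_weight * kld_loss
            + mse_loss_weight * mse_loss)

# Equal weights (1, 1, 1):
print(combined_loss(2.0, 1.0, 0.5))                       # 3.5
# Imbalanced weights (1, 0.5, 0.5):
print(combined_loss(2.0, 1.0, 0.5, 1.0, 0.5, 0.5))        # 2.75
```

Both calls do the same arithmetic, which is why a 2x runtime difference between weight settings is surprising.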
Proposed solution No clue
Environment (please complete the following information):
- GPUs: EleutherAI V100 cluster
- Configs: The defaults in the distill-gpt-neox branch