torchtitan
[Feature] Support skipping bad grad updates
Hi, I'm wondering if it would be OK to support skipping inf/nan gradients when bad data is encountered, something like this:
```python
if job_config.training.skip_nan_inf and (
    grad_norm.isnan() or grad_norm.isinf()
):
    logger.warning(
        f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
    )
    optimizers.zero_grad()
    train_state.skipped_step += 1
else:
    optimizers.step()
    lr_schedulers.step()
```
I'd be very glad to create a PR if you think it's necessary, @tianyu-l.
How do we know the nan/inf is not caused by bad training/modeling/hyper-parameters? Would it be better for training to stop when encountering a bad loss, so the model author can debug further?
I have two comments:
- I was going to say the same thing as @fegin pointed out -- it can potentially make things hard to debug with this option on. Could you share some scenarios where skipping is the proper behavior? In the case of bad data, maybe it's better for whoever is using the bad data to put the warning and skipping in their own code, rather than putting it in torchtitan.
- One can achieve this without modifying `train.py`, by doing the check and skip in a customized optimizer, inside the `step()` function call (see the sketch after this list).
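For concreteness, here is a minimal sketch of what that could look like, assuming a plain PyTorch optimizer; the wrapper class and names are illustrative, not torchtitan's actual optimizer API:

```python
import torch
from torch.optim import Optimizer


class SkipNonFiniteOptimizer:
    """Wraps an optimizer and skips the update whenever any gradient is non-finite.

    Illustrative sketch only; torchtitan's optimizer container has its own API.
    """

    def __init__(self, inner: Optimizer):
        self.inner = inner
        self.skipped_steps = 0

    def _grads_are_finite(self) -> bool:
        # One pass over all parameter gradients; non-finite means inf or nan.
        for group in self.inner.param_groups:
            for p in group["params"]:
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    return False
        return True

    def step(self, closure=None):
        if self._grads_are_finite():
            return self.inner.step(closure)
        # Bad gradients: drop them and count the skip instead of applying the update.
        self.skipped_steps += 1
        self.inner.zero_grad()
        return None

    def zero_grad(self, set_to_none: bool = True):
        self.inner.zero_grad(set_to_none=set_to_none)
```

With a wrapper like this, `train.py` can keep calling `optimizers.step()` unconditionally, and the skip logic (plus a `skipped_steps` counter for monitoring) stays encapsulated inside the optimizer.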
@fegin @tianyu-l I’d like to share a few scenarios where skipping gradient updates might be beneficial. For instance, when working with data from multiple fields or datasets that aren’t perfectly filtered, we might encounter gradients with extremely large norms or invalid values (e.g., inf/nan). Applying such gradients directly during optimization could harm the model’s performance. In these cases, skipping the update for those specific steps could help maintain training stability and final model quality. Typically, addressing this issue would require manually shutting down the training process, debugging, and restarting, which can be time-consuming.
Adding an option to automatically skip steps based on gradient norms (e.g., inf/nan detection or thresholding) could streamline this process. To ensure this doesn’t complicate debugging, we could introduce a monitoring metric like skipped_batches to track how often skips occur. Ideally, skipped_batches would increase only rarely, indicating that the training process is handling unavoidable edge cases.
If it increases rapidly, it could signal a deeper issue (e.g., a training crash), prompting immediate investigation.
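As a rough sketch of what tracking such a metric could look like (the `TrainState` fields and helper below are illustrative, not torchtitan's actual train state):

```python
from dataclasses import dataclass


@dataclass
class TrainState:
    step: int = 0
    skipped_batches: int = 0  # should stay near zero in healthy runs


def record_step(train_state: TrainState, skipped: bool) -> dict:
    """Bump the counters and return the values to hand to the metrics logger."""
    train_state.step += 1
    if skipped:
        train_state.skipped_batches += 1
    # In practice these would be emitted alongside loss/grad_norm (TensorBoard, W&B, ...).
    return {
        "step": train_state.step,
        "skipped_batches": train_state.skipped_batches,
    }
```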
I also noticed that the InternLM project employs a similar approach, which suggests that this feature could be useful in practice. So I'm wondering whether you need this; that's why I raised this issue for discussion.
Thank you.
> One can achieve this without modifying `train.py`, by doing the check and skip in a customized optimizer, inside the `step()` function call.
Yes, of course this is an alternative way!
Just wanted to add that if we do want to support this, then longer term it may be a lot more performant to re-purpose Ke's CUDA kernel as the nan checker; that is much faster than a simple check written in PyTorch eager. (cc @kwen2501)
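For reference, a simple eager-mode check of the kind referred to above might look like the sketch below (not the kernel itself); each `.all()` result forces a host sync per parameter, which is the overhead a fused kernel would avoid:

```python
import torch


def grads_are_finite(model: torch.nn.Module) -> bool:
    """Eager-mode check: True iff every existing gradient contains only finite values."""
    return all(
        bool(torch.isfinite(p.grad).all())
        for p in model.parameters()
        if p.grad is not None
    )
```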