returnn
Ignore a single broken gradient
In my current language model training I sometimes get "nan" gradients, which break the training. Surprisingly, just restarting the training from the last checkpoint often introduces enough non-determinism that training resumes without problems.
Here people discussed something like:
# Check all gradients for NaN/Inf and skip the parameter update if any are found.
valid_gradients = True
for name, param in self.named_parameters():
    if param.grad is not None:
        valid_gradients = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
        if not valid_gradients:
            break
if not valid_gradients:
    print('detected inf or nan values in gradients. not updating model parameters')
    self.zero_grad()
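For context, this check would run after the backward pass has computed the gradients and before the optimizer applies the update (i.e. between loss.backward() and optimizer.step()), so that a broken gradient simply results in one skipped update.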
I think it would be a good idea to have this as a configurable option for the updater, preferably with a "limit", so that it still crashes after e.g. 5 broken updates.
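A minimal sketch of what that could look like as a wrapper around a plain PyTorch optimizer step; the class name NanInfGradientFilter and the option max_skipped_updates are made up for illustration and are not existing RETURNN options:

import torch


class NanInfGradientFilter:
    """Skips optimizer steps whose gradients contain NaN/Inf values,
    and raises an error once too many updates have been skipped.
    (Illustrative sketch only, not an existing RETURNN API.)"""

    def __init__(self, model, optimizer, max_skipped_updates=5):
        self.model = model
        self.optimizer = optimizer
        self.max_skipped_updates = max_skipped_updates
        self.num_skipped_updates = 0

    def step(self):
        # Gradients are valid if every entry is finite (neither NaN nor Inf).
        grads_valid = all(
            bool(torch.isfinite(param.grad).all())
            for param in self.model.parameters()
            if param.grad is not None
        )
        if grads_valid:
            self.optimizer.step()
        else:
            self.num_skipped_updates += 1
            print("detected inf or nan values in gradients, skipping update (%d of at most %d)"
                  % (self.num_skipped_updates, self.max_skipped_updates))
            if self.num_skipped_updates >= self.max_skipped_updates:
                raise RuntimeError("too many broken updates, stopping training")
        # In both cases, clear the gradients before the next step.
        self.optimizer.zero_grad()

In the training loop it would then be called in place of the usual optimizer step:

loss.backward()
grad_filter.step()  # skips the update if gradients are broken

Whether the limit should count all broken updates or only consecutive ones (resetting after a successful step) is a design choice; the sketch above counts all of them.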