cryodrgn
Stop training on nan
cryodrgn train_vae should break if a NaN loss is hit, instead of silently continuing for the specified number of epochs.
@vineetbansal could you add this feature? Thanks!
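A minimal sketch of the requested behavior, in plain Python for illustration (the function and variable names here are hypothetical, not cryodrgn's actual code; in train_vae the loss is a tensor, so `torch.isfinite(loss)` would be the equivalent check):

```python
import math

def train(losses_per_epoch):
    """Iterate over per-batch loss values and abort on the first
    non-finite one instead of silently finishing all epochs.

    losses_per_epoch: iterable of iterables of float losses
    (illustrative stand-in for the real training loop).
    """
    for epoch, losses in enumerate(losses_per_epoch):
        for it, loss in enumerate(losses):
            # math.isfinite is False for nan and +/-inf
            if not math.isfinite(loss):
                raise RuntimeError(
                    f"Non-finite loss {loss} at epoch {epoch}, iteration {it}; "
                    "stopping training early"
                )
    return "completed"
```

Raising rather than logging ensures the run fails loudly, so a long job does not burn its remaining epochs on corrupted parameters.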
The --do-pose-sgd option can frequently cause pose parameters to become NaN, because pose_optimizer.step() is not scaled during automatic mixed precision training. I fixed my code as below:
```python
if do_pose_sgd and epoch >= args.pretrain:
    if args.amp:
        optim.zero_grad()
        scaler.step(pose_optimizer)
        scaler.update()
    else:
        pose_optimizer.step()
    if torch.any(torch.isnan(posetracker.rots_emb.weight[ind])) or torch.any(
        torch.isnan(posetracker.trans_emb.weight[ind])
    ):
        raise RuntimeError("NaN found in pose.")
```
scaler.step() can detect NaN/inf gradients and skip the optimizer step, which avoids corrupting the parameters. It should be noted that this style is incompatible with apex.amp.