
Improve diagnostics for when models diverge

Open danpovey opened this issue 2 years ago • 1 comment

Guys, occasionally, as I tune my latest version of the new setup, I get divergence. It looks like this:

<trains normally up until now, then:>
2022-11-20 01:34:30,723 INFO [train.py:916] Epoch 20, batch 850, loss[loss=0.1723, simple_loss=0.2561, pruned_loss=0.04421, over 14772.00 frames. ], tot_loss[loss=0.1731, simple_loss=0.2563, pruned_loss=0.04491, over 2919430.18 frames. ], batch size: 35, lr: 7.94e-03, grad_scale: 0.0009765625
2022-11-20 01:34:40,945 WARNING [optim.py:359] Scaling gradients by 0.010860392823815346, model_norm_threshold=741.3968505859375
2022-11-20 01:34:45,373 WARNING [optim.py:359] Scaling gradients by 0.0011294195428490639, model_norm_threshold=741.3968505859375
2022-11-20 01:34:46,461 WARNING [optim.py:359] Scaling gradients by 5.275502917356789e-05, model_norm_threshold=741.3968505859375
2022-11-20 01:34:48,423 WARNING [optim.py:359] Scaling gradients by 0.06526870280504227, model_norm_threshold=741.3968505859375
2022-11-20 01:34:50,249 INFO [scaling.py:615] Whitening: num_groups=1, num_channels=384, metric=21.22 vs. limit=5.0
2022-11-20 01:34:50,878 WARNING [optim.py:359] Scaling gradients by 2.0925328499288298e-05, model_norm_threshold=741.3968505859375
2022-11-20 01:34:55,214 WARNING [optim.py:359] Scaling gradients by 0.08923347294330597, model_norm_threshold=741.3968505859375
2022-11-20 01:34:56,287 WARNING [optim.py:359] Scaling gradients by 1.3554388260672567e-06, model_norm_threshold=741.3968505859375
2022-11-20 01:34:59,388 WARNING [optim.py:359] Scaling gradients by 6.9577804424625356e-06, model_norm_threshold=741.3968505859375
2022-11-20 01:35:00,574 WARNING [optim.py:359] Scaling gradients by 0.07936681807041168, model_norm_threshold=741.3968505859375
2022-11-20 01:35:06,339 WARNING [optim.py:359] Scaling gradients by 0.015136143192648888, model_norm_threshold=741.3968505859375
2022-11-20 01:35:07,441 WARNING [optim.py:359] Scaling gradients by 7.634332064299088e-07, model_norm_threshold=741.3968505859375
2022-11-20 01:35:08,455 WARNING [optim.py:359] Scaling gradients by 0.08001881092786789, model_norm_threshold=741.3968505859375
2022-11-20 01:35:12,603 WARNING [optim.py:359] Scaling gradients by 0.031115038320422173, model_norm_threshold=741.3968505859375
2022-11-20 01:35:13,623 WARNING [optim.py:359] Scaling gradients by 0.0533587820827961, model_norm_threshold=741.3968505859375
2022-11-20 01:35:17,439 WARNING [optim.py:359] Scaling gradients by 0.0014232676476240158, model_norm_threshold=741.3968505859375
2022-11-20 01:35:18,464 WARNING [optim.py:359] Scaling gradients by 0.07824313640594482, model_norm_threshold=741.3968505859375
2022-11-20 01:35:19,205 INFO [scaling.py:615] Whitening: num_groups=1, num_channels=384, metric=114.88 vs. limit=10.0
2022-11-20 01:35:19,563 WARNING [optim.py:359] Scaling gradients by 0.03925001621246338, model_norm_threshold=741.3968505859375
2022-11-20 01:35:20,525 WARNING [optim.py:359] Scaling gradients by 0.09417390078306198, model_norm_threshold=741.3968505859375
2022-11-20 01:35:23,953 WARNING [train.py:908] Grad scale is small: 1.9073486328125e-06
Traceback (most recent call last):
  File "./pruned_transducer_stateless7/train.py", line 1244, in <module>
    main()
  File "./pruned_transducer_stateless7/train.py", line 1237, in main
    run(rank=0, world_size=1, args=args)
  File "./pruned_transducer_stateless7/train.py", line 1117, in run
    train_one_epoch(
  File "./pruned_transducer_stateless7/train.py", line 910, in train_one_epoch
    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
RuntimeError: grad_scale is too small, exiting: 1.9073486328125e-06

The "scaling gradients by" message comes from the optimizer, it is detecting very large (much larger than normal) gradients and the gradient clipping mechanism is being activated. What I want to do is to discover, when this happens, whether a particular model parameter is dominating the "big gradients". E.g. to print out the name of the parameter that has the largest contribution to the "tot_sumsq" metric and what proportion of the "tot_sumsq" metric it constitutes. And I want to do this in such a way that we don't incur any extra CPU<->GPU transfers, or ideally, any extra code, in the "non-error" case.

This will require storing some extra information in the optimizer object. Right now it takes model.parameters() as an arg. It would have to be changed to [possibly optionally] take model.named_parameters() as an arg, store the names in some way, and then provide just the parameters to the base-class constructor. You will have to understand how the BatchedOptimizer base class works, to implement this. It's OK to change that base class also, if necessary, but try to keep the design elegant.
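
One possible shape for that constructor change, just as a sketch (the class and attribute names here are illustrative, not the actual BatchedOptimizer API):

    from torch.optim import Optimizer

    class NamedBatchedOptimizer(Optimizer):
        # Illustrative only: accept either model.parameters() or
        # model.named_parameters(); remember the names when given, and
        # pass the bare parameters on to the base-class constructor.
        def __init__(self, params, defaults):
            params = list(params)
            if params and isinstance(params[0], tuple):
                # Called with named_parameters(): split names from tensors.
                self.param_names = [name for name, _ in params]
                params = [p for _, p in params]
            else:
                # Called with plain parameters(): fall back to positions.
                self.param_names = [f"param_{i}" for i in range(len(params))]
            super().__init__(params, defaults)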

Any extra work that needs to be done for each minibatch, I'd like to be done, ideally, inside this if-statement:

            if ans < 0.1:
                logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")  

so that it will only cause extra work to be done if we have already detected "something bad is happening". Also, please segregate the extra code that is required at this point into a separate function call so that there is good separation between normal-case and debugging code.

Something else to consider is that the "param_rms" value can become out of date: it is only recomputed every 10 iterations or so, and if the model is diverging the difference may be significant. So it may be worthwhile, once we detect a problem, to recompute a more accurate version of param_rms.
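
Recomputing it fresh is cheap once the warning path has already fired; something like this (assuming param_rms is, as its name suggests, the root-mean-square of the parameter tensor):

    import torch

    def fresh_param_rms(p: torch.Tensor) -> torch.Tensor:
        # Recompute the parameter's RMS from scratch instead of trusting
        # a cached value that is only refreshed every ~10 steps.
        return (p.detach() ** 2).mean().sqrt()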

danpovey · Nov 20 '22 03:11

> Also, please segregate the extra code that is required at this point into a separate function call so that there is good separation between normal-case and debugging code.

May I ask two questions? Q1: Does the newly added code only calculate and log some diagnostic information, leaving the training procedure unaffected?

Q2: Would this extra function work something like this?

            if ans < 0.1:
                logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
                if params.debugging:   # could be disabled in the normal case
                    # Find the name of the parameter that contributes most to
                    # tot_sumsq, and what proportion of tot_sumsq it makes up.
                    show_gradient_dominating_parameter()

glynpu · Nov 21 '22 14:11