
Got warnings when loading sampler's state_dict

Open · csukuangfj opened this issue 2 years ago · 11 comments

I get the following warning while trying to use https://github.com/k2-fsa/icefall/pull/259 to restore the state dict of a sampler from a checkpoint.

lhotse/dataset/sampling/simple.py:144: UserWarning: SimpleCutSampler.load_state_dict():
 Inconsistent time_constraint:
expected TimeConstraint(max_duration=10, max_samples=None, max_frames=None, current=0, num_cuts=0)
received TimeConstraint(max_duration=10, max_samples=None, max_frames=None, current=32.968312499999996, num_cuts=2)

Related code is listed below: https://github.com/k2-fsa/icefall/blob/ae564f91e6981321a715d3ce1ddf5dec5cc21296/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py#L300

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L427

    save_checkpoint_impl(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        sampler=sampler,
        rank=rank,
    )

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L800

    if checkpoints and "sampler" in checkpoints:
        sampler_state_dict = checkpoints["sampler"]
    else:
        sampler_state_dict = None

    train_dl = librispeech.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )
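
For reference, the save/restore round trip that these snippets exercise boils down to something like the following. This is only a minimal sketch: cuts stands in for whatever CutSet is used for training, the manifest path is made up, and the constructor arguments are assumptions rather than the exact icefall configuration.

    from lhotse import CutSet
    from lhotse.dataset.sampling.simple import SimpleCutSampler

    cuts = CutSet.from_file("cuts_train.jsonl.gz")  # hypothetical manifest path

    # Training run: consume a batch, then checkpoint the sampler state.
    sampler = SimpleCutSampler(cuts, max_duration=10, shuffle=True)
    batches = iter(sampler)
    next(batches)  # the sampler now has in-progress counters (current, num_cuts)
    state = sampler.state_dict()  # this is what icefall stores in the checkpoint

    # Resumed run: a freshly constructed sampler has zeroed counters, so
    # loading a mid-epoch state may trigger the warning quoted above.
    restored = SimpleCutSampler(cuts, max_duration=10, shuffle=True)
    restored.load_state_dict(state)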

csukuangfj · Mar 21 '22

Also, I find that the sampler's state_dict contains only the epoch that was current when the checkpoint was saved. It does not record at which batch within that epoch the checkpoint was saved.

See https://github.com/lhotse-speech/lhotse/blob/b3f219407438b86d1a23f8d47f60f55b7709d1d9/lhotse/dataset/sampling/base.py#L132-L139

https://github.com/lhotse-speech/lhotse/blob/b3f219407438b86d1a23f8d47f60f55b7709d1d9/lhotse/dataset/sampling/simple.py#L112-L119

https://github.com/lhotse-speech/lhotse/blob/b3f219407438b86d1a23f8d47f60f55b7709d1d9/lhotse/dataset/sampling/bucketing.py#L217-L230


As a result, I have to use the following code from https://github.com/k2-fsa/icefall/blob/ae564f91e6981321a715d3ce1ddf5dec5cc21296/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L618

    # Restore the batch index stored in the checkpoint (0 when starting fresh).
    cur_batch_idx = params.get("cur_batch_idx", 0)

    for batch_idx, batch in enumerate(train_dl):
        # Fast-forward past the batches that were already processed
        # before the checkpoint was written.
        if batch_idx < cur_batch_idx:
            continue
        cur_batch_idx = batch_idx

to skip the batches that were already processed when resuming training from a checkpoint, which may take several minutes.
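
For completeness, the saving side of that workaround has to record the position explicitly before the checkpoint is written. The sketch below is only an approximation of the icefall-style bookkeeping: params is treated as a plain dict here, and save_checkpoint_impl is the same helper shown earlier.

    # Inside the training loop, just before writing a checkpoint (sketch only).
    params["cur_batch_idx"] = batch_idx  # remember how far this epoch has progressed
    save_checkpoint_impl(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        sampler=sampler,
        rank=rank,
    )
    del params["cur_batch_idx"]  # avoid carrying the counter into later epochs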

csukuangfj · Mar 21 '22

You're right that it doesn't store batch_idx; instead it stores the number of cuts that were already processed (it's inside the diagnostics). That is sufficient for the sampler to restore its state correctly, so you don't need your workaround with its long waiting times.

pzelasko · Mar 21 '22

As to the warning, I'll take another look at it later; it probably shouldn't compare the current and num_cuts fields.
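
A comparison that ignores those progress counters could look roughly like the sketch below. This is not the lhotse implementation, just an illustration assuming TimeConstraint is a dataclass with the field names shown in the warning.

    from dataclasses import replace

    def constraints_match(expected, received):
        # Zero the in-progress counters on both sides before comparing, so a
        # state saved mid-epoch does not trigger a spurious mismatch warning.
        return (replace(expected, current=0, num_cuts=0)
                == replace(received, current=0, num_cuts=0))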

pzelasko · Mar 21 '22

Actually I was wrong: both the number of cuts and the number of batches consumed are kept in the diagnostics, so you can read out everything. However, the diagnostics are reset after every epoch; I think it makes sense to change that so they keep accumulating. WDYT?

Also check this PR, which addresses some of your other comments: https://github.com/lhotse-speech/lhotse/pull/632

pzelasko · Mar 21 '22

> That is sufficient for the sampler to restore its state correctly, so you don't need your workaround with its long waiting times.

The issue is that batch_idx in the following line starts from 0 even when we resume training from a checkpoint. That may confuse users, since it looks as if training does not pick up from the point where the checkpoint was saved.

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless/train.py#L620

    for batch_idx, batch in enumerate(train_dl):

csukuangfj · Mar 22 '22

you’d need to do something like:

for batch_idx, batch in enumerate(train_dl, start=train_dl.sampler.diagnostics.num_batches_kept)

You will encounter two issues with that:

  1. the diagnostics are reset after each epoch, but I will push a fix for that.
  2. there is going to be a difference in batch_idx equal to num_workers * prefetch_factor - 1, because in the script that saved the state dict, the dataloader workers had already "consumed" some CutSet batches from the sampler that were not actually used in training yet (see the short illustration after this list). I am not sure how we could work around it, but it’s probably not a big issue for large datasets.
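
To make the size of that offset concrete, here is the arithmetic with made-up loader settings; the numbers are just an example, not what icefall uses.

    num_workers = 2
    prefetch_factor = 2  # the PyTorch DataLoader default when num_workers > 0

    # Batches the workers had already pulled from the sampler but that the
    # training loop had not yet consumed when the checkpoint was written:
    offset = num_workers * prefetch_factor - 1
    print(offset)  # -> 3, i.e. the restored position can be ahead by 3 batches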

pzelasko · Mar 22 '22

> you’d need to do something like:

Thanks!

csukuangfj · Mar 22 '22

> for batch_idx, batch in enumerate(train_dl, start=train_dl.sampler.diagnostics.num_batches_kept)

There is only num_kept_batches (see https://github.com/lhotse-speech/lhotse/blob/3685d8c6fc8f4e3c773ac4e851b2265f38c05115/lhotse/dataset/sampling/base.py#L387); there is no num_batches_kept.

Also, I find that train_dl.sampler.diagnostics.num_kept_batches is always 0 during training, at least for the first several batches. Is that expected?

csukuangfj · Mar 23 '22

Also, if I use start in enumerate, I think it changes only batch_idx; the dataloader still yields its elements from the beginning.

>>> a = list(range(10))
>>> a
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> for i, v in enumerate(a, start=3): print(i, v)
...
3 0
4 1
5 2
6 3
7 4
8 5
9 6
10 7
11 8
12 9
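
For contrast, actually skipping already-consumed elements would require slicing the iterator itself, e.g. with itertools.islice. This is only to illustrate the difference; with a correctly restored sampler such skipping should not be necessary.

>>> from itertools import islice
>>> for i, v in enumerate(islice(a, 3, None), start=3): print(i, v)
...
3 3
4 4
5 5
6 6
7 7
8 8
9 9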

csukuangfj · Mar 23 '22

I am also noticing that the loss is better after we re-load from a checkpoint, around 0.05 -> 0.04. I suspect something about the SpecAugment settings may have changed. It is not just a transient issue; the loss stays lower. I'm a bit concerned that this feature might be a bug farm.

danpovey · Mar 23 '22

Yes I think you’re right about it potentially being a bug farm.

I looked into it a bit yesterday; BucketingSampler and ZipSampler need some fixes to support it correctly. Regarding SpecAugment, I don't know what could have changed, other than the RNG state being different from what it was at the point where training stopped.
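
If the RNG state is the culprit, one way to rule it out would be to checkpoint and restore the generator states alongside the model. The sketch below uses standard stdlib/NumPy/PyTorch calls; it is not something lhotse or icefall does for you here, just an assumption about where the difference could come from.

    import random
    import numpy as np
    import torch

    def get_rng_states():
        # Capture the generator states that augmentation such as SpecAugment may draw from.
        return {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
        }

    def set_rng_states(states):
        random.setstate(states["python"])
        np.random.set_state(states["numpy"])
        torch.set_rng_state(states["torch"])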

I will commit a fix later that should resolve the bucketing sampler issue.

pzelasko · Mar 23 '22