
on-the-fly fbank feats

Cescfangs opened this issue 1 year ago · 17 comments

Hey guys, I noticed there's an on-the-fly features option in asr_datamodule.py: https://github.com/k2-fsa/icefall/blob/32de2766d591d2e1a77c06a40d2861fb1bbcd3ad/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py#L279-L298

However, I didn't find any recipe that uses it. How can I use on-the-fly features instead of computing fbank first? (I'm using a large dataset, so computing fbank locally is not possible due to disk capacity.) BTW, does icefall support reading feats.scp directly like ESPnet does? (We have precomputed Kaldi fbank.)

Cescfangs avatar Nov 08 '22 08:11 Cescfangs

Passing

--on-the-fly-feats true

to train.py will do.

You can use

train.py --help

to view the usage.
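For context, that option just switches the dataset's input strategy; below is a simplified sketch of what the linked asr_datamodule.py does (not the exact code; note that the cuts must still reference the raw recordings):

    from lhotse import Fbank, FbankConfig
    from lhotse.dataset import K2SpeechRecognitionDataset
    from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures

    # With --on-the-fly-feats true, fbank is computed from the raw audio referenced
    # by the cuts; otherwise precomputed features are loaded from disk.
    train = K2SpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
        if on_the_fly_feats          # boolean taken from the command-line option
        else PrecomputedFeatures(),
        return_cuts=True,
    )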

csukuangfj avatar Nov 08 '22 08:11 csukuangfj

Thanks for the reply. So I can just skip the fbank-computation stages in prepare.sh?

Cescfangs avatar Nov 08 '22 08:11 Cescfangs

Yes, I believe so.

csukuangfj avatar Nov 08 '22 08:11 csukuangfj

Thanks mate.

Cescfangs avatar Nov 08 '22 08:11 Cescfangs

@csukuangfj I found that on-the-fly feature computation makes training much slower. For example, 50 batch iterations take about 20 seconds with precomputed Kaldi fbank features, but about 4 minutes with on-the-fly computation under the same conditions. I noticed you have trained with on-the-fly feats on large datasets (https://github.com/k2-fsa/icefall/pull/312#issuecomment-1096641908); how did you resolve this problem?

Cescfangs avatar Aug 08 '23 10:08 Cescfangs

Are you using raw waves? Also, is your disk fast?

csukuangfj avatar Aug 08 '23 11:08 csukuangfj

Are you using raw waves? Also, is your disk fast?

Yes, I'm using raw waves. How can I check whether my disk is fast or slow?

Cescfangs avatar Aug 08 '23 11:08 Cescfangs

BTW, I've trained on raw waves with ESPnet and the GPU utilization was around 70%, which I think is normal. The difference is that in ESPnet I implement fbank as a frontend layer (part of the model, running on the GPU), so maybe my disk is not the bottleneck?

Cescfangs avatar Aug 08 '23 11:08 Cescfangs

Can you try increasing the number of dataloader workers? Perhaps that’s the bottleneck.

If you want to use fbank as a layer you can modify the code to use https://github.com/lhotse-speech/lhotse/blob/eb9e6b115729697c66c0a7f5f7ba08984b6a1ee5/lhotse/features/kaldi/layers.py#L476

If it turns out to be a slow-disk problem, you can speed up the I/O at the cost of an extra copy of the data using: https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb
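For example, a minimal sketch of using that layer as a GPU frontend (assuming 16 kHz audio and the default fbank settings; the exact class behind that link may differ):

    import torch
    from lhotse.features.kaldi.layers import Wav2LogFilterBank

    device = torch.device("cuda")
    fbank = Wav2LogFilterBank().to(device)        # Kaldi-compatible log-mel fbank as an nn.Module
    audio = torch.randn(4, 16000, device=device)  # (batch, num_samples): dummy 1-second waveforms
    feats = fbank(audio)                          # (batch, num_frames, num_filters)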

pzelasko avatar Aug 08 '23 11:08 pzelasko

Can you try increasing the number of dataloader workers? Perhaps that’s the bottleneck.

If you want to use fbank as a layer you can modify the code to use https://github.com/lhotse-speech/lhotse/blob/eb9e6b115729697c66c0a7f5f7ba08984b6a1ee5/lhotse/features/kaldi/layers.py#L476

If it turns out to be a slow-disk problem, you can speed up the I/O at the cost of an extra copy of the data using: https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb

@pzelasko thanks for the advice, I have done some experiments:

| feature_type | dataloader workers | OnTheFlyFeatures workers | runtime (50 batches) |
| --- | --- | --- | --- |
| precomputed fbank | 4 | - | 20 s |
| KaldifeatFbank | 4 | 0 | 240 s |
| KaldifeatFbank | 4 | 4 | 140 s |
| + lhotse.set_caching_enabled(True) | 4 | 4 | 130 s |
| | 8 | 4 | 80 s |
| | 8 | 8 | 80 s |
| | 16 | 4 | 50 s |

I'm using KaldifeatFbank because it's compatible with Kaldi.
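For reference, the on-the-fly rows in the table correspond roughly to a setup like this (a sketch; num_workers here is what the "OnTheFlyFeatures workers" column refers to):

    import lhotse
    from lhotse.dataset.input_strategies import OnTheFlyFeatures
    from lhotse.features.kaldifeat import KaldifeatFbank, KaldifeatFbankConfig

    lhotse.set_caching_enabled(True)  # the "+lhotse.set_caching_enabled(True)" rows
    input_strategy = OnTheFlyFeatures(
        KaldifeatFbank(KaldifeatFbankConfig()),
        num_workers=4,  # the "OnTheFlyFeatures workers" column
    )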

Cescfangs avatar Aug 08 '23 12:08 Cescfangs

I'm using KaldifeatFbank because it's compatible with Kaldi.

Can you try to use GPU to extract features?

KaldifeatFbank supports GPU. If you are using DDP, you can use device="cuda:0", device="cuda:1", etc., to specify the device.
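For example (a sketch; rank stands for whatever DDP rank index your training script already has):

    from lhotse.features.kaldifeat import KaldifeatFbank, KaldifeatFbankConfig

    # rank: the DDP process index (a placeholder variable here)
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=f"cuda:{rank}"))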

csukuangfj avatar Aug 08 '23 12:08 csukuangfj

You are getting the best gains by increasing the number of dataloader workers, so it's likely an I/O bottleneck; using WebDataset or Lhotse Shar may help.

BTW, the fbank I posted is also compatible with Kaldi. Note that regardless of which one you choose, you'll need to move the fbank computation from the dataloader to the training loop to leverage the GPU.
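Roughly, that pattern looks like this (a sketch; device and train_dl come from your training script, and AudioSamples makes the dataloader return padded raw audio instead of features):

    from lhotse.dataset import K2SpeechRecognitionDataset
    from lhotse.dataset.input_strategies import AudioSamples
    from lhotse.features.kaldi.layers import Wav2LogFilterBank

    # The dataloader workers only read and pad raw audio; no CUDA in worker processes.
    dataset = K2SpeechRecognitionDataset(input_strategy=AudioSamples())
    fbank = Wav2LogFilterBank().to(device)  # device = this rank's GPU

    for batch in train_dl:
        audio = batch["inputs"].to(device)  # (batch, num_samples)
        feats = fbank(audio)                # (batch, num_frames, num_filters)
        # ... forward/backward as usual ...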

pzelasko avatar Aug 08 '23 13:08 pzelasko

@pzelasko I got errors when increasing world_size from 1 to 8. Can you give me some advice?

"2023-08-09T09:46:00+08:00" malloc(): invalid size (unsorted)
"2023-08-09T09:46:00+08:00" malloc(): invalid size (unsorted)
"2023-08-09T09:46:01+08:00" malloc(): invalid size (unsorted)
"2023-08-09T09:46:02+08:00" malloc(): invalid size (unsorted)
"2023-08-09T09:46:03+08:00" Traceback (most recent call last):
"2023-08-09T09:46:03+08:00"   File "./pruned_transducer_stateless5_bs/train.py", line 1475, in <module>
"2023-08-09T09:46:03+08:00"     main()
"2023-08-09T09:46:03+08:00"   File "./pruned_transducer_stateless5_bs/train.py", line 1464, in main
"2023-08-09T09:46:03+08:00"     mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
"2023-08-09T09:46:03+08:00"     return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
"2023-08-09T09:46:03+08:00"     while not context.join():
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
"2023-08-09T09:46:03+08:00"     raise ProcessRaisedException(msg, error_index, failed_process.pid)
"2023-08-09T09:46:03+08:00" torch.multiprocessing.spawn.ProcessRaisedException: 
"2023-08-09T09:46:03+08:00" 
"2023-08-09T09:46:03+08:00" -- Process 0 terminated with the following error:
"2023-08-09T09:46:03+08:00" Traceback (most recent call last):
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
"2023-08-09T09:46:03+08:00"     fn(i, *args)
"2023-08-09T09:46:03+08:00"   File "/data1/icefall-master/egs/hik/asr2_audio/pruned_transducer_stateless5_bs/train.py", line 1340, in run
"2023-08-09T09:46:03+08:00"     train_one_epoch(
"2023-08-09T09:46:03+08:00"   File "/data1/icefall-master/egs/hik/asr2_audio/pruned_transducer_stateless5_bs/train.py", line 1043, in train_one_epoch
"2023-08-09T09:46:03+08:00"     for batch_idx, batch in enumerate(train_dl):
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 438, in __iter__
"2023-08-09T09:46:03+08:00"     return self._get_iterator()
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 384, in _get_iterator
"2023-08-09T09:46:03+08:00"     return _MultiProcessingDataLoaderIter(self)
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1086, in __init__
"2023-08-09T09:46:03+08:00"     self._reset(loader, first_iter=True)
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1119, in _reset
"2023-08-09T09:46:03+08:00"     self._try_put_index()
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1353, in _try_put_index
"2023-08-09T09:46:03+08:00"     index = self._next_index()
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 642, in _next_index
"2023-08-09T09:46:03+08:00"     return next(self._sampler_iter)  # may raise StopIteration
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/base.py", line 261, in __next__
"2023-08-09T09:46:03+08:00"     batch = self._next_batch()
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/dynamic_bucketing.py", line 237, in _next_batch
"2023-08-09T09:46:03+08:00"     batch = next(self.cuts_iter)
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/dynamic_bucketing.py", line 360, in __iter__
"2023-08-09T09:46:03+08:00"     ready_buckets = [b for b in self.buckets if is_ready(b)]
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/dynamic_bucketing.py", line 360, in <listcomp>
"2023-08-09T09:46:03+08:00"     ready_buckets = [b for b in self.buckets if is_ready(b)]
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/dynamic_bucketing.py", line 351, in is_ready
"2023-08-09T09:46:03+08:00"     tot.add(c[0] if isinstance(c, tuple) else c)
"2023-08-09T09:46:03+08:00"   File "/data1/tools/lhotse-master/lhotse/dataset/sampling/base.py", line 350, in add
"2023-08-09T09:46:03+08:00"     self.current += cut.duration
"2023-08-09T09:46:03+08:00"   File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
"2023-08-09T09:46:03+08:00"     _error_if_any_worker_fails()
"2023-08-09T09:46:03+08:00" RuntimeError: DataLoader worker (pid 76863) is killed by signal: Aborted. 
"2023-08-09T09:46:03+08:00" 
"2023-08-09T09:46:03+08:00" [INFO] recv error: exit status 1
"2023-08-09T09:46:03+08:00" [ERROR] error happends during process: exit status 1
"2023-08-09T09:46:03+08:00" [INFO] still reserved
"2023-08-09T09:46:03+08:00" [INFO] recv flag (false)
"2023-08-09T09:46:03+08:00" [INFO] sleeping

Cescfangs avatar Aug 09 '23 02:08 Cescfangs

Can you reduce the number of workers (especially for the on-the-fly features) and see if it helps?

pzelasko avatar Aug 09 '23 02:08 pzelasko

I'm using KaldifeatFbank because it's compatible with Kaldi.

Can you try to use GPU to extract features?

KaldifeatFbank supports GPU. If you are using DDP, you can use device="cuda:0", device="cuda:1", etc., to specify the device.

@csukuangfj I tried using the GPU for feature extraction, but it seems that we can't re-initialize CUDA in a forked subprocess:

Original Traceback (most recent call last):
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = self.dataset[possibly_batched_index]
  File "/data1/tools/lhotse-master/lhotse/dataset/speech_recognition.py", line 113, in __getitem__
    input_tpl = self.input_strategy(cuts)
  File "/data1/tools/lhotse-master/lhotse/dataset/input_strategies.py", line 380, in __call__
    features_single = self.extractor.extract_batch(
  File "/data1/tools/lhotse-master/lhotse/features/kaldifeat.py", line 84, in extract_batch
    return self.extract(samples=samples, sampling_rate=sampling_rate)
  File "/data1/tools/lhotse-master/lhotse/features/kaldifeat.py", line 125, in extract
    result = self.extractor(samples, chunk_size=self.config.chunk_size)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/kaldifeat/offline_feature.py", line 79, in forward
    features = self.compute(strided, vtln_warp, chunk_size=chunk_size)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/kaldifeat/offline_feature.py", line 135, in compute
    x[end:].to(self_device), vtln_warp
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/cuda/__init__.py", line 207, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Cescfangs avatar Aug 09 '23 02:08 Cescfangs

Can you reduce the number of workers (especially for the on-the-fly features) and see if it helps?

Yes. With world_size=8 I can run at most 8 dataloader workers and 1 OnTheFlyFeatures worker; the average runtime for 50 batches is around 95 s, which seems very reasonable according to the table (https://github.com/k2-fsa/icefall/issues/666#issuecomment-1669542121).

| feature_type | dataloader workers | OnTheFlyFeatures workers | runtime (50 batches) |
| --- | --- | --- | --- |
| + lhotse.set_caching_enabled(True) | 4 | 4 | 130 s |
| | 8 | 4 | 80 s |

You are getting the best gains by increasing the number of dataloader workers, so it's likely an I/O bottleneck; using WebDataset or Lhotse Shar may help.

Is there any icefall recipe using WebDataset or Lhotse Shar that I can follow?

Cescfangs avatar Aug 09 '23 03:08 Cescfangs

AFAIK there's no recipe at this time, but it shouldn't be too involved:

  • export your train cutset to Lhotse Shar format, e.g. with lhotse shar export --help
  • adjust the dataloading according to the example at the end of this notebook tutorial: https://github.com/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb

Be aware that it will create a full copy of your audio data. A rough Python sketch of both steps is below.
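(A sketch only; the paths and shard size are placeholders, and the notebook covers shard shuffling and the dataloader changes.)

    # 1. Export an existing CutSet to Lhotse Shar (this copies the audio into tar shards).
    from lhotse import CutSet

    cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # placeholder path
    cuts.to_shar("data/shar/train", fields={"recording": "wav"}, shard_size=1000)

    # 2. Load it lazily for training.
    train_cuts = CutSet.from_shar(in_dir="data/shar/train")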

pzelasko avatar Aug 09 '23 03:08 pzelasko