
Test CTC/LF-MMI interpolation

Open danpovey opened this issue 4 years ago • 57 comments

This is fairly easy to do: in the current MMI script, just add a scaling factor on the denominator likelihood, like 0.9 or 0.8 or something.
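
Concretely, something like this (a sketch only; den_scale and the exact variable names are placeholders for whatever the script uses):

# den_scale = 1.0 is plain LF-MMI; den_scale = 0.0 keeps only the numerator
# term (i.e. CTC-like), so values like 0.9 or 0.8 interpolate between the two.
den_scale = 0.9
tot_scores = num_tot_scores - den_scale * den_tot_scores
loss = -tot_scores.sum()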

danpovey avatar Jan 07 '21 08:01 danpovey

Would be nice if @jimbozhang or @pzelasko could help.

danpovey avatar Jan 07 '21 08:01 danpovey

Would it be soon enough if I test it next week? Or maybe @pzelasko has time to do this. :neutral_face:

jimbozhang avatar Jan 07 '21 09:01 jimbozhang

OK I can check this out

pzelasko avatar Jan 07 '21 19:01 pzelasko

In LF-MMI training, I see an error in epoch 2:

Traceback (most recent call last):
  File "mmi_bigram_train.py", line 428, in <module>
    main()
  File "mmi_bigram_train.py", line 386, in main
    global_batch_idx_valid=global_batch_idx_valid)
  File "mmi_bigram_train.py", line 201, in train_one_epoch
    get_objf(batch, model, P, device, graph_compiler, True, optimizer)
  File "mmi_bigram_train.py", line 144, in get_objf
    (-tot_score).backward()
  File "/home/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

The GPU is a GTX 1080Ti; details about Python/PyTorch and k2:

conda list | grep torch
pytorch                   1.7.1           py3.7_cuda10.2.89_cudnn7.6.5_0    pytorch

pip list | grep k2
k2                     0.1.2+cu102.dev20210107

Which options/env vars should I set to see a more detailed error trace?
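
I guess the generic CUDA/PyTorch advice applies here - making kernel launches synchronous so the traceback points closer to the failing op, e.g.:

CUDA_LAUNCH_BLOCKING=1 python mmi_bigram_train.py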

pzelasko avatar Jan 07 '21 21:01 pzelasko

Anyway, from the first 2 epochs, it seems like scaling the denominator down does not help. I'm adding a tensorboard screenshot - it's rendered weirdly, likely because each subsequent epoch resets the batch index to zero, so it looks a bit like a ping-pong trajectory (easy to fix). Green - baseline, Blue - den scale 0.9, Red - den scale 0.8.

[tensorboard screenshot]
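
The rendering fix I have in mind is just to log against a running global step instead of the per-epoch batch index - a minimal sketch (the tag and loop bounds are made up):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('exp/tensorboard')
global_step = 0                    # keeps increasing across epochs
for epoch in range(10):
    for batch_idx in range(100):   # stand-in for iterating over train_dl
        objf = float(global_step)  # stand-in for the real per-batch objf
        # log against global_step, not batch_idx, so epochs don't overlap
        writer.add_scalar('train/objf', objf, global_step)
        global_step += 1
writer.close()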

I'm also gonna try with scales 1.1 and 1.2.

For full reference, I apply it like:

num_tot_scores = k2.get_tot_scores(num,
                                   log_semiring=True,
                                   use_double_scores=True)
den_tot_scores = k2.get_tot_scores(den,
                                   log_semiring=True,
                                   use_double_scores=True)
tot_scores = num_tot_scores - den_scale * den_tot_scores

pzelasko avatar Jan 07 '21 21:01 pzelasko

RE tensorboard:

[screenshot: TensorBoard settings showing the horizontal axis options]

Please select RELATIVE.

csukuangfj avatar Jan 07 '21 22:01 csukuangfj

RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

Probably it's not an error from k2. I just googled it and found there is a lot of information about it: https://www.google.com/search?newwindow=1&sxsrf=ALeKk02HYjLsBWP9U5y_meNArr1OlbiiWg%3A1610057571186&source=hp&ei=Y4f3X6umCNLxkwXljbzAAQ&q=RuntimeError%3A+cuDNN+error%3A+CUDNN_STATUS_EXECUTION_FAILED&oq=RuntimeError%3A+cuDNN+error%3A+CUDNN_STATUS_EXECUTION_FAILED&gs_lcp=CgZwc3ktYWIQA1DiBVjiBWD2BWgAcAB4AIABAIgBAJIBAJgBAKABAqABAaoBB2d3cy13aXo&sclient=psy-ab&ved=0ahUKEwjr4-Hs64ruAhXS-KQKHeUGDxgQ4dUDCAk&uact=5

csukuangfj avatar Jan 07 '21 22:01 csukuangfj

Thanks. So that's what the "relative" option is for ;) updated screens:

[updated tensorboard screenshot with the RELATIVE horizontal axis]

Re cuDNN: you're right, I was probably too quick to assume a k2 error. I'll try to figure it out, but just to make sure: you encountered no such error?

EDIT: now that I think of it, it might be something related to our grid - I had the same issue in the past with Espresso LSTM models... so nvm :)

pzelasko avatar Jan 08 '21 00:01 pzelasko

but just to make sure, you encountered no such error?

I did not encounter this problem while running the training script from the master.

csukuangfj avatar Jan 08 '21 00:01 csukuangfj

I have hit the same error in asg_train.py with the aishell recipe training.

Traceback (most recent call last):
  File "asg_train.py", line 428, in <module>
    main()
  File "asg_train.py", line 375, in main
    objf = train_one_epoch(dataloader=train_dl,
  File "asg_train.py", line 197, in train_one_epoch
    get_objf(batch, model, P, device, graph_compiler, True, optimizer)
  File "asg_train.py", line 140, in get_objf
    (-tot_score).backward()
  File "/asr_storage/fanlu/miniconda3/envs/k2/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/asr_storage/fanlu/miniconda3/envs/k2/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

fanlu avatar Jan 08 '21 02:01 fanlu

There are many upvotes for the following solution: https://github.com/NVIDIA/tacotron2/issues/109#issuecomment-471940890

Not sure whether it works as I could not reproduce the problem.

csukuangfj avatar Jan 08 '21 02:01 csukuangfj

This error occurred at around 14k iterations into training; the iteration count is so large because I changed max_frames from 90000 to 10000 to avoid the OOM error. My torch and CUDA env versions match.

conda list | grep torch
torch                     1.7.0+cu101              pypi_0    pypi
torchaudio                0.7.0                    pypi_0    pypi

pip list | grep k2
k2                     0.1.2.dev20210105

nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Fri_Feb__8_19:08:17_PST_2019
Cuda compilation tools, release 10.1, V10.1.105

fanlu avatar Jan 08 '21 02:01 fanlu

The suggested solution uses pip, not conda. I am also not using conda. I don't know if there are any problems with PyTorch installed with conda.

csukuangfj avatar Jan 08 '21 03:01 csukuangfj

Unfortunately, not using conda is a luxury that some of us (who don't have sudo privileges/can't mess with their compute infrastructure) cannot afford... Maybe it's possible to do a hybrid conda/pip installation of pytorch but I have not tried yet.

pzelasko avatar Jan 08 '21 03:01 pzelasko

If you are using release mode, doing export K2_SYNC_KERNELS=1 before running may help -- it's possible it is an error from k2 code that was not caught, as it was the last k2 kernel to run before cuDNN. The error you are seeing is from the Torch part of the code, so I don't think there is anything we can do from k2 itself to affect tracing.

danpovey avatar Jan 08 '21 03:01 danpovey

Unfortunately, not using conda is a luxury that some of us

I am using pyenv, which does not require sudo permissions. It can manage different python versions for you without installing extra things.
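
For reference, a minimal pyenv example (the version number is just illustrative):

pyenv install 3.8.6   # builds this Python under ~/.pyenv, no sudo needed
pyenv local 3.8.6     # use it inside the current project directory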


If you are using release mode, doing export K2_SYNC_KERNELS=1

I am afraid it works only in Debug mode.

csukuangfj avatar Jan 08 '21 04:01 csukuangfj

OK. Will have to compile in debug mode then. I suspect this is a previous error or an issue with sizes or something, rather than an installation problem. I suspect with den-scale > 1.0 it may diverge or something else nasty.

danpovey avatar Jan 08 '21 04:01 danpovey

It's not only about the Python version - conda also allows you to install native dependencies (including CUDA, MKL, cuDNN, even gcc) in the versions you want in a very simple way without sudo. I guess it's all doable without conda, but probably not worth the time investment... Anyway, it's not a super big issue; the training can be restarted from a checkpoint (maybe we need to save checkpoints every N steps rather than after each epoch).
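
By step-based checkpoints I mean something like this (a sketch only; the names and the interval are made up):

import torch

SAVE_EVERY = 1000  # made-up interval

def maybe_save_checkpoint(model, optimizer, global_step, exp_dir='exp'):
    # Save every SAVE_EVERY steps instead of only at epoch boundaries, so a
    # mid-epoch crash loses at most SAVE_EVERY steps of work.
    if global_step > 0 and global_step % SAVE_EVERY == 0:
        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'global_step': global_step,
            },
            f'{exp_dir}/checkpoint-{global_step}.pt',
        )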

pzelasko avatar Jan 08 '21 05:01 pzelasko

Using the latest code in Debug mode,

cmake -DCMAKE_BUILD_TYPE=Debug ..
make -j

the following problem occurred:

nvlink error   : Entry function '_ZN3cub31DeviceRadixSortSingleTileKernelINS_21DeviceRadixSortPolicyIiiiE9Policy700ELb0EiiiEEvPKT1_PS4_PKT2_PS8_T3_ii' uses too much shared data (0xdc50 bytes, 0xc000 max) (target: sm_50)
nvlink error   : Entry function '_ZN3cub30DeviceRadixSortDownsweepKernelINS_21DeviceRadixSortPolicyIiiiE9Policy700ELb0ELb0EiiiEEvPKT2_PS4_PKT3_PS8_PT4_SC_iiNS_13GridEvenShareISC_EE' uses too much shared data (0xfab0 bytes, 0xc000 max) (target: sm_50)
nvlink error   : Entry function '_ZN3cub23RadixSortScanBinsKernelINS_21DeviceRadixSortPolicyIiiiE9Policy700EiEEvPT0_i' uses too much shared data (0x11020 bytes, 0xc000 max) (target: sm_50)
nvlink error   : Entry function '_ZN3cub30DeviceRadixSortDownsweepKernelINS_21DeviceRadixSortPolicyIiiiE9Policy700ELb1ELb0EiiiEEvPKT2_PS4_PKT3_PS8_PT4_SC_iiNS_13GridEvenShareISC_EE' uses too much shared data (0xe0b0 bytes, 0xc000 max) (target: sm_50)
make[2]: *** [k2/csrc/CMakeFiles/context.dir/cmake_device_link.o] Error 255
make[1]: *** [k2/csrc/CMakeFiles/context.dir/all] Error 2
make: *** [all] Error 2

fanlu avatar Jan 08 '21 05:01 fanlu

the problem below occured

Please either upgrade your NVCC from 10.1.105 to 10.1.243 or re-pull the latest k2.

csukuangfj avatar Jan 08 '21 07:01 csukuangfj

I have built the latest code in Debug mode and trained asg_train.py with K2_SYNC_KERNELS=1, but the problem RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED does not disappear.

export K2_SYNC_KERNELS=1;export CUDA_VISIBLE_DEVICES=2;PYTHONPATH=/mnt/cfs1_alias1/asr/users/fanlu/task/snowfall:$PYTHONPATH; python asg_train.py

fanlu avatar Jan 08 '21 10:01 fanlu

Try running it with cuda-memcheck.
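
E.g. just prefix the usual training command (--leak-check full is optional):

cuda-memcheck --leak-check full python asg_train.py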

danpovey avatar Jan 08 '21 10:01 danpovey

Some results:

den_scale = 0.8

2021-01-08 12:22:01,328 INFO [mmi_bigram_decode.py:297] %WER 12.41% [6525 / 52576, 872 ins, 732 del, 4921 sub ]

den_scale = 0.9

2021-01-08 17:26:31,018 INFO [mmi_bigram_decode.py:297] %WER 13.15% [6912 / 52576, 1056 ins, 764 del, 5092 sub ]

The reference training (den_scale = 1.0) went wrong, so I'm re-running it from the start (got 47.48% WER for some reason) and I'll report the number when I have it. I believe @csukuangfj was getting about 11~11.5%?

pzelasko avatar Jan 08 '21 22:01 pzelasko

The best WER I got was 11.19% using the current master branch and the checkpoint can be downloaded from the following link: https://drive.google.com/file/d/1pIElpm6Ij8YzOEh_GpiBlun5BwryhzPj/view?usp=sharing

Hope that you can reproduce it locally.

csukuangfj avatar Jan 09 '21 00:01 csukuangfj

Yeah, I reproduced with:

2021-01-09 19:24:38,925 INFO [mmi_bigram_decode.py:297] %WER 11.09% [5831 / 52576, 985 ins, 453 del, 4393 sub ]

However, that's only after I removed old L, G, etc. and re-ran the whole thing from scratch (otherwise, for two runs I was somehow getting 35 and 44% WER). I will now re-run the den_scale expts, as the previous setup might have had some mismatch.

pzelasko avatar Jan 10 '21 01:01 pzelasko

I've got the results for the new training runs:

den_scale = 0.9

2021-01-10 15:50:54,008 INFO [mmi_bigram_decode.py:297] %WER 11.42% [6005 / 52576, 661 ins, 817 del, 4527 sub ]

den_scale = 0.8

2021-01-10 10:39:58,444 INFO [mmi_bigram_decode.py:297] %WER 11.20% [5890 / 52576, 706 ins, 747 del, 4437 sub ]

They seem marginally worse than the no-interpolation training (11.09% WER). What's interesting is that 0.8 is better than 0.9, so maybe even lower weights could help. I'm not sure how much %WER spread we get when training the same model due to randomness; it seems at least 0.1%, since @csukuangfj got 11.19% WER with this setup. It's hard to claim significance with such marginal changes in scores.

pzelasko avatar Jan 10 '21 20:01 pzelasko

OK, thanks. One reason we might eventually want to try den_scale less than 1 is to encourage sparsity in the non-blank symbols, to make the lattices sparser and the search faster. No urgency on that though. Dan

danpovey avatar Jan 11 '21 03:01 danpovey

@danpovey @csukuangfj I have used cuda-memcheck to check memory usage. Maybe there is a memory error in k2 when training on long sequences. I am not sure.

Traceback (most recent call last):
  File "mmi_bigram_train.py", line 141, in get_objf
    (-tot_score).backward()
  File "/asr_storage/fanlu/miniconda3/envs/k2/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/asr_storage/fanlu/miniconda3/envs/k2/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
Traceback (most recent call last):
  File "mmi_bigram_train.py", line 431, in <module>
    main()
  File "mmi_bigram_train.py", line 378, in main
    objf = train_one_epoch(dataloader=train_dl,
  File "mmi_bigram_train.py", line 201, in train_one_epoch
    get_objf(batch, model, P, device, graph_compiler, True, optimizer)
  File "mmi_bigram_train.py", line 97, in get_objf
    nnet_output = model(feature)
  File "/asr_storage/fanlu/miniconda3/envs/k2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/cfs1_alias1/asr/users/fanlu/task/snowfall/snowfall/models/tdnn_lstm.py", line 156, in forward
    x_new, _ = lstm(x)
  File "/asr_storage/fanlu/miniconda3/envs/k2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/asr_storage/fanlu/miniconda3/envs/k2/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 581, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
========= LEAK SUMMARY: 0 bytes leaked in 0 allocations
========= ERROR SUMMARY: 1 error

fanlu avatar Jan 11 '21 07:01 fanlu

Mm. It may be possible to run with valgrind on CPU and it would find the error, but possibly the error would just disappear on CPU or would take too long to reach. Does it seem repeatable, e.g. at a particular minibatch?

danpovey avatar Jan 11 '21 07:01 danpovey

Sorry, it's not repeatable - this error occurs at about 6k~8k, 10k~12k, 14k+, etc., so I don't know which batch causes the problem. In the Aishell train dataset, the longest sequence is about 14.53125 s. I have already trained the librispeech recipe and got CER 11.44% without errors.

fanlu avatar Jan 11 '21 08:01 fanlu