loss nan after several epochs
Dear @albertfgu and @tridao,
Thanks for sharing the awesome work.
I'm trying to incorporate Mamba into a CNN. Training goes well for the first few hundred epochs (100-750), but then the loss turns into NaN for unknown reasons.
I can confirm that nan occurred in mamba blocks because this issue disappeared after removing them. However, I have no idea how to debug this issue. Would it be possible for you to give me some hints to fix this issue?
Here are the logs. Please feel free to let me know if you need more details.
2023-12-20 00:48:15.329871: Epoch 629
2023-12-20 00:48:15.330647: Current learning rate: 6e-05
2023-12-20 00:50:44.726418: train_loss -0.7834
2023-12-20 00:50:44.728141: val_loss -0.7892
2023-12-20 00:50:44.729907: Epoch time: 149.4 s
2023-12-20 00:50:46.116153:
2023-12-20 00:50:46.117122: Epoch 630
2023-12-20 00:50:46.117851: Current learning rate: 6e-05
2023-12-20 00:53:15.241138: train_loss -0.7657
2023-12-20 00:53:15.243143: val_loss -0.7889
2023-12-20 00:53:15.244522: Epoch time: 149.13 s
2023-12-20 00:53:16.622655:
2023-12-20 00:53:16.623522: Epoch 631
2023-12-20 00:53:16.624225: Current learning rate: 6e-05
2023-12-20 00:55:45.777329: train_loss nan
2023-12-20 00:55:45.779537: val_loss -0.7908
2023-12-20 00:55:45.781009: Epoch time: 149.16 s
2023-12-20 00:55:47.172864:
2023-12-20 00:55:47.173798: Epoch 632
2023-12-20 00:55:47.174612: Current learning rate: 6e-05
2023-12-20 00:58:16.292152: train_loss nan
2023-12-20 00:58:16.293823: val_loss -0.7766
2023-12-20 00:58:16.295210: Epoch time: 149.12 s
2023-12-20 00:58:17.693203:
2023-12-20 00:58:17.694208: Epoch 633
2023-12-20 00:58:17.694858: Current learning rate: 6e-05
2023-12-20 01:00:46.692787: train_loss nan
2023-12-20 01:00:46.695605: val_loss -0.7742
2023-12-20 01:00:46.697324: Epoch time: 149.0 s
2023-12-20 01:00:48.099071:
2023-12-20 01:00:48.099914: Epoch 634
2023-12-20 01:00:48.100704: Current learning rate: 6e-05
2023-12-20 01:03:17.243219: train_loss nan
2023-12-20 01:03:17.245350: val_loss -0.7739
2023-12-20 01:03:17.246872: Epoch time: 149.15 s
2023-12-20 01:03:18.642867:
2023-12-20 01:03:18.643775: Epoch 635
2023-12-20 01:03:18.644512: Current learning rate: 6e-05
2023-12-20 01:05:47.672137: train_loss nan
2023-12-20 01:05:47.674026: val_loss -0.7846
2023-12-20 01:05:47.675463: Epoch time: 149.03 s
2023-12-20 01:05:49.063004:
2023-12-20 01:05:49.063804: Epoch 636
2023-12-20 01:05:49.064505: Current learning rate: 6e-05
2023-12-20 01:08:18.338836: train_loss nan
2023-12-20 01:08:18.341044: val_loss -0.7697
2023-12-20 01:08:18.342629: Epoch time: 149.28 s
2023-12-20 01:08:19.724974:
2023-12-20 01:08:19.725791: Epoch 637
2023-12-20 01:08:19.726496: Current learning rate: 6e-05
2023-12-20 01:10:48.727470: train_loss nan
2023-12-20 01:10:48.729979: val_loss -0.7753
2023-12-20 01:10:48.731454: Epoch time: 149.0 s
2023-12-20 01:10:50.143582:
2023-12-20 01:10:50.144486: Epoch 638
2023-12-20 01:10:50.145250: Current learning rate: 6e-05
2023-12-20 01:13:19.113122: train_loss nan
2023-12-20 01:13:19.114945: val_loss -0.79
2023-12-20 01:13:19.116437: Epoch time: 148.97 s
2023-12-20 01:13:20.502590:
2023-12-20 01:13:20.503405: Epoch 639
2023-12-20 01:13:20.504042: Current learning rate: 6e-05
2023-12-20 01:15:49.631775: train_loss nan
2023-12-20 01:15:49.633824: val_loss -0.7596
2023-12-20 01:15:49.635192: Epoch time: 149.13 s
2023-12-20 01:15:51.028755:
2023-12-20 01:15:51.029576: Epoch 640
2023-12-20 01:15:51.030221: Current learning rate: 6e-05
2023-12-20 01:18:20.038192: train_loss nan
2023-12-20 01:18:20.039899: val_loss -0.7918
2023-12-20 01:18:20.041280: Epoch time: 149.01 s
2023-12-20 01:18:21.429107:
2023-12-20 01:18:21.429842: Epoch 641
2023-12-20 01:18:21.430510: Current learning rate: 6e-05
2023-12-20 01:20:50.438274: train_loss nan
2023-12-20 01:20:50.440479: val_loss -0.7908
2023-12-20 01:20:50.441905: Epoch time: 149.01 s
2023-12-20 01:20:51.974714:
2023-12-20 01:20:51.975639: Epoch 642
2023-12-20 01:20:51.976502: Current learning rate: 6e-05
2023-12-20 01:23:21.094214: train_loss nan
2023-12-20 01:23:21.096366: val_loss -0.7719
2023-12-20 01:23:21.097925: Epoch time: 149.12 s
2023-12-20 01:23:22.477658:
2023-12-20 01:23:22.478508: Epoch 643
2023-12-20 01:23:22.479224: Current learning rate: 6e-05
2023-12-20 01:25:51.455657: train_loss nan
2023-12-20 01:25:51.457551: val_loss -0.7594
2023-12-20 01:25:51.458950: Epoch time: 148.98 s
2023-12-20 01:25:52.854940:
2023-12-20 01:25:52.855765: Epoch 644
2023-12-20 01:25:52.856454: Current learning rate: 6e-05
2023-12-20 01:28:21.948064: train_loss nan
2023-12-20 01:28:21.950029: val_loss -0.7908
2023-12-20 01:28:21.951765: Epoch time: 149.1 s
2023-12-20 01:28:23.363405:
2023-12-20 01:28:23.370147: Epoch 645
2023-12-20 01:28:23.376197: Current learning rate: 6e-05
2023-12-20 01:30:52.547230: train_loss nan
2023-12-20 01:30:52.549585: val_loss -0.7503
2023-12-20 01:30:52.551248: Epoch time: 149.19 s
2023-12-20 01:30:53.940921:
2023-12-20 01:30:53.941931: Epoch 646
2023-12-20 01:30:53.942829: Current learning rate: 6e-05
2023-12-20 01:33:22.897284: train_loss nan
2023-12-20 01:33:22.899078: val_loss nan
2023-12-20 01:33:22.900609: Epoch time: 148.96 s
2023-12-20 01:33:24.480330:
2023-12-20 01:33:24.481319: Epoch 647
2023-12-20 01:33:24.481999: Current learning rate: 6e-05
2023-12-20 01:35:53.530202: train_loss nan
2023-12-20 01:35:53.532310: val_loss -0.759
2023-12-20 01:35:53.533979: Epoch time: 149.05 s
2023-12-20 01:35:54.926128:
2023-12-20 01:35:54.927284: Epoch 648
2023-12-20 01:35:54.928841: Current learning rate: 6e-05
2023-12-20 01:38:23.353006: train_loss nan
2023-12-20 01:38:23.355011: val_loss nan
2023-12-20 01:38:23.356395: Epoch time: 148.43 s
2023-12-20 01:38:24.749097:
2023-12-20 01:38:24.749952: Epoch 649
2023-12-20 01:38:24.750711: Current learning rate: 6e-05
2023-12-20 01:40:50.282438: train_loss nan
2023-12-20 01:40:50.284293: val_loss nan
2023-12-20 01:40:50.285918: Epoch time: 145.54 s
2023-12-20 01:40:53.765802:
2023-12-20 01:40:53.766853: Epoch 650
2023-12-20 01:40:53.767912: Current learning rate: 6e-05
2023-12-20 01:43:19.223994: train_loss nan
2023-12-20 01:43:19.225788: val_loss nan
2023-12-20 01:43:19.227218: Epoch time: 145.46 s
After digging into the details, we found that the values output by selective_scan_fn are very large in magnitude, even though the inputs to the function have normal magnitudes. The Mamba output then causes an overflow in a later step.
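One way to narrow down where magnitudes first blow up is to register forward hooks that record each module's output range and report the first module whose output contains NaN/inf. Below is a minimal sketch in plain PyTorch. `record_output_ranges` and `first_non_finite` are hypothetical helper names, and the toy model stands in for the CNN + Mamba network.

```python
import torch
import torch.nn as nn

def record_output_ranges(model: nn.Module):
    """Attach forward hooks that log (name, min, max, all_finite) for each submodule output.

    Returns the list that is filled during forward, plus the hook handles
    so the hooks can be removed afterwards.
    """
    stats, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # Only plain floating-point tensor outputs are inspected; tuples etc. are skipped.
            if isinstance(output, torch.Tensor) and output.is_floating_point():
                out = output.detach()
                stats.append((name, out.min().item(), out.max().item(),
                              bool(torch.isfinite(out).all())))
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    return stats, handles

def first_non_finite(stats):
    """Name of the first module, in execution order, whose output had NaN/inf."""
    for name, _, _, finite in stats:
        if not finite:
            return name
    return None

# Toy usage: plant an inf weight in the last Linear layer on purpose.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
with torch.no_grad():
    model[2].weight[0, 0] = float("inf")
stats, handles = record_output_ranges(model)
model(torch.randn(3, 4))
for h in handles:
    h.remove()
print(first_non_finite(stats))  # 2
```

Hooks only fire for modules whose forward is actually called. Fused kernels can use a submodule's weights directly without calling the submodule, so in a Mamba block this may only localize the problem to the block as a whole.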
Could you please share your insights on why the output has such large values in magnitude? I have been using the setup as shown in README:
from mamba_ssm import Mamba

model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
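The "3 * expand * d_model^2" comment can be checked by hand. Assuming the two large weight matrices in the block are an input projection d_model -> 2 * expand * d_model (producing x and the gate z) and an output projection expand * d_model -> d_model, both without bias, they contribute 2 * expand * d^2 + expand * d^2 = 3 * expand * d^2 weights; the convolution and SSM parameters are comparatively small. `approx_mamba_params` is a hypothetical helper for that arithmetic only.

```python
def approx_mamba_params(d_model: int, expand: int = 2) -> int:
    """Approximate parameter count from the in/out projections only (assumed layout)."""
    d_inner = expand * d_model
    in_proj = d_model * (2 * d_inner)  # d_model -> 2 * d_inner (x and gate z)
    out_proj = d_inner * d_model       # d_inner -> d_model
    return in_proj + out_proj          # equals 3 * expand * d_model**2

print(approx_mamba_params(256))  # 393216
```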
For your reference: https://github.com/state-spaces/mamba/blob/eb2f7a520dd5e2949b7ae1c3ef44f6cb99faef5c/mamba_ssm/ops/selective_scan_interface.py#L37
Input value range (min, max):
u: (-0.27846455574035645, 51.10154724121094)
delta: (-11.5971097946167, 10.911433219909668)
A: (-16.0, -1.0)
B: (-17.975683212280273, 19.113466262817383)
C: (-14.787280082702637, 13.634928703308105)
D: (1.0, 1.0)
z: (-117.9038314819336, 110.53551483154297)
delta_bias: (-6.710234642028809, -2.2756412029266357)
delta_softplus is True
Output value range:
out: (-13671.662109375, 15236.4501953125)
x: (-70.63794708251953, 42.27799987792969)
rest[0]: (-283149.8125, 1407336.75) (the final value returned by selective_scan_fn)
After the last linear projection layer in Mamba:
out: (-173402.859375, 188013.328125)
Could these large outputs be avoided by changing the model's hyperparameters?
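The logged ranges already hint at where the growth can come from. Assuming the usual selective-scan recurrence (dt = softplus(delta + delta_bias), h_t = exp(dt * A) * h_{t-1} + dt * B * u_t, y_t = C * h_t + D * u_t, then y gated by silu(z)), every A is negative, so exp(dt * A) <= 1 and the state does not grow exponentially by itself. The large values must instead come from the input term dt * B * u, the readout by C, and the silu(z) gate, which multiply together. The sketch below plugs the logged extremes into a single step. These are rough worst-case bounds that assume the per-tensor extremes coincide in the same channel and time step; they are not measured values.

```python
import math

def softplus(x: float) -> float:
    return math.log1p(math.exp(x))

def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))

# Extremes copied from the logged input ranges.
delta_max, delta_bias_max = 10.911433219909668, -2.2756412029266357
u_max, B_absmax, C_absmax = 51.10154724121094, 19.113466262817383, 14.787280082702637
z_max = 110.53551483154297
A_max = -1.0  # least negative A, i.e. the slowest decay

dt_max = softplus(delta_max + delta_bias_max)  # ~8.64
decay_at_dt_max = math.exp(dt_max * A_max)     # < 1: the state itself decays
input_term = dt_max * B_absmax * u_max         # one step's contribution to h
readout = C_absmax * input_term                # after multiplying by C
gated = readout * silu(z_max)                  # after the silu(z) gate

print(f"dt_max={dt_max:.2f} decay={decay_at_dt_max:.2e} "
      f"input={input_term:.3g} readout={readout:.3g} gated={gated:.3g}")
```

Even a single step can reach the 1e4 to 1e7 range under these assumptions, which is consistent with the orders of magnitude seen in rest[0].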
How are you initializing your weights and cached states? I found that the numbers blew up for me like this when I started from a random initialization of the cached ssm_state and conv_state instead of zeros.
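For reference, zero-initialized caches could look like the sketch below. The shapes are an assumption based on my reading of Mamba's allocate_inference_cache (conv_state: batch x d_inner x d_conv, ssm_state: batch x d_inner x d_state, with d_inner = expand * d_model); `make_zero_states` is a hypothetical helper, and calling the module's own allocate_inference_cache is the safer route in practice.

```python
import torch

def make_zero_states(batch_size, d_model, d_state=16, d_conv=4, expand=2,
                     device="cpu", dtype=torch.float32):
    """Zero-filled conv/SSM caches (assumed shapes), instead of random tensors."""
    d_inner = expand * d_model
    conv_state = torch.zeros(batch_size, d_inner, d_conv, device=device, dtype=dtype)
    ssm_state = torch.zeros(batch_size, d_inner, d_state, device=device, dtype=dtype)
    return conv_state, ssm_state

conv_state, ssm_state = make_zero_states(batch_size=2, d_model=64)
print(conv_state.shape, ssm_state.shape)  # torch.Size([2, 128, 4]) torch.Size([2, 128, 16])
```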
Hi, the numbers are also blowing up for me. I tried normalizing the output with layer norm, but it didn't help much. I also noticed that neither ssm_state nor conv_state is actually used, which was also mentioned in #51. I would like to know if you have any insights into this. Thanks!
I got the same problem when I replaced an S4 block with Mamba (an S6 block): the output was also exploding. In my task, I used the forward function during training, and at inference I called allocate_inference_cache before using the step function.
I also got the same NaN problem. I didn't investigate, so I don't know whether it is caused by values blowing up. All I can say is that it happened in a run with a smaller batch size (256) but not with a larger one (1024).
Could this be related to the troubleshooting section of the README: https://github.com/state-spaces/mamba?tab=readme-ov-file#troubleshooting
Hi all,
We finally found that disabling AMP and reducing the learning rate solves this issue.
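A quick precision-related check before or after changing the AMP setup is to confirm that the model's own parameters, the master weights the optimizer updates, are still stored in fp32; AMP is meant to lower the precision of computations, not of the stored weights. `non_fp32_params` is a hypothetical helper, and the toy model is a placeholder.

```python
import torch
import torch.nn as nn

def non_fp32_params(model: nn.Module):
    """Return the names of parameters that are not stored in float32."""
    return [name for name, p in model.named_parameters() if p.dtype != torch.float32]

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
print(non_fp32_params(model))  # []
model[1].half()                # simulate a layer that was converted to fp16
print(non_fp32_params(model))  # ['1.weight', '1.bias']
```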
Can you show me how to disable AMP @JunMa11 ?
I found that disabling AMP did not seem to help in my use case, and lowering the LR just degrades the metrics on this particular dataset. Does anyone have further suggestions?
To disable AMP, simply don't wrap the training code in a "with autocast" block. Without that context, Automatic Mixed Precision is not applied. If you also use a GradScaler, create it with enabled=False or drop it.
If you want to keep AMP, you can mitigate NaN issues with gradient scaling: https://pytorch.org/docs/stable/amp.html#gradient-scaling. In fp16, small gradient values can underflow to zero; scaling the loss before backward() makes this less likely, and GradScaler also skips the optimizer step when it finds inf/NaN gradients.
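Both suggestions fit in one training step, sketched below. With use_amp=False the autocast context is a no-op and a disabled GradScaler passes the loss and optimizer step through unchanged, so everything runs in plain fp32. With use_amp=True the loss is scaled before backward() and the step is skipped if inf/NaN gradients are found. The toy model, MSE loss, and `train_step` helper are placeholders, and the non-finite-loss check is an addition so a bad batch fails loudly.

```python
import torch
import torch.nn as nn

def train_step(model, x, y, optimizer, scaler, use_amp, max_grad_norm=1.0):
    device_type = "cuda" if x.is_cuda else "cpu"
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    optimizer.zero_grad(set_to_none=True)
    # enabled=False makes autocast a no-op: the forward pass runs in fp32.
    with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=use_amp):
        loss = nn.functional.mse_loss(model(x), y)
    if not torch.isfinite(loss):
        raise FloatingPointError(f"non-finite loss: {loss.item()}")
    scaler.scale(loss).backward()        # returns the loss unchanged if the scaler is disabled
    scaler.unscale_(optimizer)           # so clipping sees the true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)               # skipped when gradients contain inf/NaN
    scaler.update()
    return loss.item()

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Newer PyTorch versions prefer torch.amp.GradScaler("cuda", enabled=...).
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
x, y = torch.randn(16, 8, device=device), torch.randn(16, 1, device=device)
print(train_step(model, x, y, optimizer, scaler, use_amp))
```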
Regarding the comment that disabling AMP and lowering the LR didn't help: what optimiser are you using? I've found that for higher learning rates, using RAdam instead of Adam helps immensely. RAdam rectifies the variance of the adaptive learning rate during the first steps, which acts like a built-in learning rate warmup, whereas Adam applies its adaptive updates at the full set learning rate from the very first step. It also removes the need for a separate learning rate warmup.
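Swapping the optimiser is a one-line change, since torch.optim.RAdam ships with recent PyTorch versions. The toy regression below only shows the call; the lr and betas are placeholder values, not tuned recommendations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(64, 8)
y = x @ torch.randn(8, 1)  # synthetic linear target

model = nn.Linear(8, 1)
# Previously: optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-2, betas=(0.9, 0.999))

losses = []
for _ in range(200):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```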
Hope this helps.
Same issue here when using the bi-Mamba implementation from https://github.com/hustvl/Vim/blob/main/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py. Two things solve it: 1. fp32 training; 2. a lower learning rate. But neither is ideal:
- fp32 training: higher GPU memory use and slower training
- lower learning rate: slower convergence and worse performance
I encountered the same problem, and in my case, the training process is still unstable even after reducing the learning rate.
Hello, how do I train with fp32? When I use "with autocast(dtype=torch.float32):", I get the following error:
/mamba_ssm/ops/selective_scan_interface.py", line 301, in backward
    dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
RuntimeError: Expected delta.scalar_type() == input_type to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Could you help me? Thanks so much.
As far as I know, fp32 is the default dtype for tensors, so you shouldn't need autocast at all. Autocast is used to automatically run selected operations in a lower precision (e.g. fp16) while keeping others in their original precision (e.g. fp32). That may be your issue: try dropping autocast entirely. It may or may not help, but it's definitely worth trying.
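The error message says the delta tensor handed to the CUDA kernel has a different dtype from the input. A small guard can make such mismatches visible before the call; in plain fp32 training (no autocast, model kept in float32) every tensor should pass. `check_same_dtype` is a hypothetical helper, not part of mamba_ssm, and the tensors below are placeholders.

```python
import torch

def check_same_dtype(**tensors):
    """Raise TypeError if the given tensors do not all share one dtype; return that dtype."""
    dtypes = {name: t.dtype for name, t in tensors.items()}
    if len(set(dtypes.values())) > 1:
        raise TypeError(f"dtype mismatch: {dtypes}")
    return next(iter(dtypes.values()))

u = torch.randn(2, 8, 16)
delta = torch.randn(2, 8, 16)
print(check_same_dtype(u=u, delta=delta))  # torch.float32
try:
    check_same_dtype(u=u, delta=delta.half())
except TypeError as err:
    print(err)
```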
Thank you very much for your answer. After repeated checks, I found that I had used normalization incorrectly. Now my network runs successfully. Thanks again!
https://github.com/state-spaces/mamba/issues/72#issuecomment-2027318310
Hi, I'm running into the same error, "Expected delta.scalar_type() == input_type to be true, but got false." Could you tell me how you fixed it? Thank you very much! @EricPaul03
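A hacky workaround for overly large values (which lead to NaN outputs): zero out every element that is NaN or whose magnitude exceeds a threshold.
import torch
is_bad = torch.logical_or(torch.abs(feat) > 10, torch.isnan(feat))
feat = torch.where(is_bad, torch.zeros_like(feat), feat)  # out-of-place version of feat[is_bad] = 0, avoids in-place-modification errors in autograd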
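A softer variant of the same workaround (a swapped-in technique, not the one posted) clamps instead of zeroing, so large values keep their sign, using torch.nan_to_num followed by clamp. Like the masking trick, it hides the symptom rather than fixing the cause, and the limit of 10 is just the threshold from the snippet above, a placeholder to tune. `sanitize` is a hypothetical helper name.

```python
import torch

def sanitize(feat: torch.Tensor, limit: float = 10.0) -> torch.Tensor:
    # NaN -> 0, +inf/-inf -> +limit/-limit, then clip every value to [-limit, limit].
    return torch.nan_to_num(feat, nan=0.0, posinf=limit, neginf=-limit).clamp(-limit, limit)

feat = torch.tensor([1.5, -25.0, float("nan"), float("inf"), -float("inf")])
print(sanitize(feat).tolist())  # [1.5, -10.0, 0.0, 10.0, -10.0]
```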
I also hit the "Expected delta.scalar_type() == input_type" error. Have you found a solution? Thank you very much!
I just created a new conda environment and reinstalled all the packages (with pip's --force-reinstall, so already-installed packages are reinstalled rather than reused). Then the problem disappeared.
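Since a clean reinstall fixed it, a stale or mismatched build is a plausible culprit. One way to record what is actually installed before and after a reinstall is the standard-library sketch below; the distribution names torch, mamba_ssm, and causal_conv1d are assumptions about how the packages are registered, and `installed_versions` is a hypothetical helper.

```python
from importlib import metadata

def installed_versions(names=("torch", "mamba_ssm", "causal_conv1d")):
    """Map each distribution name to its installed version, or None if not found."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed, or registered under another name
    return versions

print(installed_versions())
```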
I increased the input dimension and the problem went away.