
loss nan after several epochs

Open · JunMa11 opened this issue 1 year ago · 24 comments

Dear @albertfgu and @tridao,

Thanks for sharing the awesome work.

I'm trying to incorporate Mamba into a CNN. Training goes well at first (for roughly 100-750 epochs), but then the loss becomes NaN for unknown reasons.

I can confirm that the NaNs come from the Mamba blocks, because the issue disappears when I remove them. However, I don't know how to debug it. Could you give me some hints on how to fix it?

Here are the logs. Please feel free to let me know if you need more details.

2023-12-20 00:48:15.329871: Epoch 629
2023-12-20 00:48:15.330647: Current learning rate: 6e-05
2023-12-20 00:50:44.726418: train_loss -0.7834
2023-12-20 00:50:44.728141: val_loss -0.7892
2023-12-20 00:50:44.729907: Epoch time: 149.4 s
2023-12-20 00:50:46.116153:
2023-12-20 00:50:46.117122: Epoch 630
2023-12-20 00:50:46.117851: Current learning rate: 6e-05
2023-12-20 00:53:15.241138: train_loss -0.7657
2023-12-20 00:53:15.243143: val_loss -0.7889
2023-12-20 00:53:15.244522: Epoch time: 149.13 s
2023-12-20 00:53:16.622655:
2023-12-20 00:53:16.623522: Epoch 631
2023-12-20 00:53:16.624225: Current learning rate: 6e-05
2023-12-20 00:55:45.777329: train_loss nan
2023-12-20 00:55:45.779537: val_loss -0.7908
2023-12-20 00:55:45.781009: Epoch time: 149.16 s
2023-12-20 00:55:47.172864:
2023-12-20 00:55:47.173798: Epoch 632
2023-12-20 00:55:47.174612: Current learning rate: 6e-05
2023-12-20 00:58:16.292152: train_loss nan
2023-12-20 00:58:16.293823: val_loss -0.7766
2023-12-20 00:58:16.295210: Epoch time: 149.12 s
2023-12-20 00:58:17.693203:
2023-12-20 00:58:17.694208: Epoch 633
2023-12-20 00:58:17.694858: Current learning rate: 6e-05
2023-12-20 01:00:46.692787: train_loss nan
2023-12-20 01:00:46.695605: val_loss -0.7742
2023-12-20 01:00:46.697324: Epoch time: 149.0 s
2023-12-20 01:00:48.099071:
2023-12-20 01:00:48.099914: Epoch 634
2023-12-20 01:00:48.100704: Current learning rate: 6e-05
2023-12-20 01:03:17.243219: train_loss nan
2023-12-20 01:03:17.245350: val_loss -0.7739
2023-12-20 01:03:17.246872: Epoch time: 149.15 s
2023-12-20 01:03:18.642867:
2023-12-20 01:03:18.643775: Epoch 635
2023-12-20 01:03:18.644512: Current learning rate: 6e-05
2023-12-20 01:05:47.672137: train_loss nan
2023-12-20 01:05:47.674026: val_loss -0.7846
2023-12-20 01:05:47.675463: Epoch time: 149.03 s
2023-12-20 01:05:49.063004:
2023-12-20 01:05:49.063804: Epoch 636
2023-12-20 01:05:49.064505: Current learning rate: 6e-05
2023-12-20 01:08:18.338836: train_loss nan
2023-12-20 01:08:18.341044: val_loss -0.7697
2023-12-20 01:08:18.342629: Epoch time: 149.28 s
2023-12-20 01:08:19.724974:
2023-12-20 01:08:19.725791: Epoch 637
2023-12-20 01:08:19.726496: Current learning rate: 6e-05
2023-12-20 01:10:48.727470: train_loss nan
2023-12-20 01:10:48.729979: val_loss -0.7753
2023-12-20 01:10:48.731454: Epoch time: 149.0 s
2023-12-20 01:10:50.143582:
2023-12-20 01:10:50.144486: Epoch 638
2023-12-20 01:10:50.145250: Current learning rate: 6e-05
2023-12-20 01:13:19.113122: train_loss nan
2023-12-20 01:13:19.114945: val_loss -0.79
2023-12-20 01:13:19.116437: Epoch time: 148.97 s
2023-12-20 01:13:20.502590:
2023-12-20 01:13:20.503405: Epoch 639
2023-12-20 01:13:20.504042: Current learning rate: 6e-05
2023-12-20 01:15:49.631775: train_loss nan
2023-12-20 01:15:49.633824: val_loss -0.7596
2023-12-20 01:15:49.635192: Epoch time: 149.13 s
2023-12-20 01:15:51.028755:
2023-12-20 01:15:51.029576: Epoch 640
2023-12-20 01:15:51.030221: Current learning rate: 6e-05
2023-12-20 01:18:20.038192: train_loss nan
2023-12-20 01:18:20.039899: val_loss -0.7918
2023-12-20 01:18:20.041280: Epoch time: 149.01 s
2023-12-20 01:18:21.429107:
2023-12-20 01:18:21.429842: Epoch 641
2023-12-20 01:18:21.430510: Current learning rate: 6e-05
2023-12-20 01:20:50.438274: train_loss nan
2023-12-20 01:20:50.440479: val_loss -0.7908
2023-12-20 01:20:50.441905: Epoch time: 149.01 s
2023-12-20 01:20:51.974714:
2023-12-20 01:20:51.975639: Epoch 642
2023-12-20 01:20:51.976502: Current learning rate: 6e-05
2023-12-20 01:23:21.094214: train_loss nan
2023-12-20 01:23:21.096366: val_loss -0.7719
2023-12-20 01:23:21.097925: Epoch time: 149.12 s
2023-12-20 01:23:22.477658:
2023-12-20 01:23:22.478508: Epoch 643
2023-12-20 01:23:22.479224: Current learning rate: 6e-05
2023-12-20 01:25:51.455657: train_loss nan
2023-12-20 01:25:51.457551: val_loss -0.7594
2023-12-20 01:25:51.458950: Epoch time: 148.98 s
2023-12-20 01:25:52.854940:
2023-12-20 01:25:52.855765: Epoch 644
2023-12-20 01:25:52.856454: Current learning rate: 6e-05
2023-12-20 01:28:21.948064: train_loss nan
2023-12-20 01:28:21.950029: val_loss -0.7908
2023-12-20 01:28:21.951765: Epoch time: 149.1 s
2023-12-20 01:28:23.363405:
2023-12-20 01:28:23.370147: Epoch 645
2023-12-20 01:28:23.376197: Current learning rate: 6e-05
2023-12-20 01:30:52.547230: train_loss nan
2023-12-20 01:30:52.549585: val_loss -0.7503
2023-12-20 01:30:52.551248: Epoch time: 149.19 s
2023-12-20 01:30:53.940921:
2023-12-20 01:30:53.941931: Epoch 646
2023-12-20 01:30:53.942829: Current learning rate: 6e-05
2023-12-20 01:33:22.897284: train_loss nan
2023-12-20 01:33:22.899078: val_loss nan
2023-12-20 01:33:22.900609: Epoch time: 148.96 s
2023-12-20 01:33:24.480330:
2023-12-20 01:33:24.481319: Epoch 647
2023-12-20 01:33:24.481999: Current learning rate: 6e-05
2023-12-20 01:35:53.530202: train_loss nan
2023-12-20 01:35:53.532310: val_loss -0.759
2023-12-20 01:35:53.533979: Epoch time: 149.05 s
2023-12-20 01:35:54.926128:
2023-12-20 01:35:54.927284: Epoch 648
2023-12-20 01:35:54.928841: Current learning rate: 6e-05
2023-12-20 01:38:23.353006: train_loss nan
2023-12-20 01:38:23.355011: val_loss nan
2023-12-20 01:38:23.356395: Epoch time: 148.43 s
2023-12-20 01:38:24.749097:
2023-12-20 01:38:24.749952: Epoch 649
2023-12-20 01:38:24.750711: Current learning rate: 6e-05
2023-12-20 01:40:50.282438: train_loss nan
2023-12-20 01:40:50.284293: val_loss nan
2023-12-20 01:40:50.285918: Epoch time: 145.54 s
2023-12-20 01:40:53.765802:
2023-12-20 01:40:53.766853: Epoch 650
2023-12-20 01:40:53.767912: Current learning rate: 6e-05
2023-12-20 01:43:19.223994: train_loss nan
2023-12-20 01:43:19.225788: val_loss nan
2023-12-20 01:43:19.227218: Epoch time: 145.46 s

JunMa11 · Dec 20 '23
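One generic way to find where NaNs first appear (a sketch, not something proposed in this thread): register a forward hook on every submodule and stop at the first one whose output is non-finite. `install_nonfinite_hooks` is a hypothetical helper name; `register_forward_hook`, `named_modules` and `torch.isfinite` are standard PyTorch. For NaNs that first appear in the backward pass, `torch.autograd.set_detect_anomaly(True)` is the built-in alternative.

```python
import torch
import torch.nn as nn


def install_nonfinite_hooks(model: nn.Module):
    """Hypothetical debugging helper: raise on the first submodule whose forward
    output contains NaN or Inf. Returns the hook handles so they can be removed."""
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Modules may return a single tensor or a tuple/list of outputs.
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for t in outputs:
                if torch.is_tensor(t) and t.is_floating_point() and not torch.isfinite(t).all():
                    raise RuntimeError(
                        f"non-finite output in module '{name}' ({type(module).__name__})"
                    )
        return hook

    # Children finish their forward before their parents, so the innermost
    # offending module is reported first.
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles
```

Remove the hooks with `for h in handles: h.remove()` once the culprit is found; each check forces a device sync, so this is only suitable while debugging.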

After digging into the details, we found that the outputs of selective_scan_fn are very large in magnitude, even though its inputs have normal magnitudes. This large Mamba output then caused an overflow in a later step.

Could you share your insights on why the output is so large? I'm using the setup shown in the README:

from mamba_ssm import Mamba

model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

For your reference: https://github.com/state-spaces/mamba/blob/eb2f7a520dd5e2949b7ae1c3ef44f6cb99faef5c/mamba_ssm/ops/selective_scan_interface.py#L37

Input value range (min, max):
u: (-0.27846455574035645, 51.10154724121094)
delta: (-11.5971097946167, 10.911433219909668)
A: (-16.0, -1.0)
B: (-17.975683212280273, 19.113466262817383)
C: (-14.787280082702637, 13.634928703308105)
D: (1.0, 1.0)
z: (-117.9038314819336, 110.53551483154297)
delta_bias: (-6.710234642028809, -2.2756412029266357)
delta_softplus: True
Output value ranges (min, max):
out: (-13671.662109375, 15236.4501953125)
x: (-70.63794708251953, 42.27799987792969)
rest[0]: (-283149.8125, 1407336.75)  (the final value returned by selective_scan_fn)
After the last linear projection layer in Mamba:
out: (-173402.859375, 188013.328125)

Could these large outputs be avoided by changing the model's hyperparameters?

JunMa11 · Dec 21 '23
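Why can inputs of ordinary size give outputs this large? A back-of-the-envelope sketch (an editorial assumption, not an answer from the maintainers): for a constant input and step size Δ, the per-state recurrence h_t = exp(ΔA) h_{t-1} + ΔB u_t converges to ΔBu / (1 - exp(ΔA)), which for small Δ|A| is approximately Bu/|A|, independent of Δ. With the ranges reported above (u up to ~51, |B| up to ~19, |A| as small as 1), a single state can approach ~10^3; C (|C| up to ~15) then sums d_state = 16 such states, and the gate multiplies by silu(z) with z up to ~110. Signs partly cancel in practice, so this is only a rough upper bound, but the reported magnitudes fall within it. The sequential pure-PyTorch sketch below assumes the shapes in its docstring and a real-valued A; `selective_scan_sketch` is a hypothetical name, and the repository's own reference implementation is `selective_scan_ref` in the file linked above.

```python
import torch
import torch.nn.functional as F


def selective_scan_sketch(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=True):
    """Simplified sequential selective scan (assumed shapes):
    u, delta, z: (batch, dim, L)   A: (dim, N)   B, C: (batch, N, L)   D, delta_bias: (dim,)
    """
    if delta_bias is not None:
        delta = delta + delta_bias[..., None]
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, L = u.shape
    h = u.new_zeros(batch, dim, A.shape[1])  # hidden state starts at zero
    ys = []
    for t in range(L):
        dA = torch.exp(delta[:, :, t, None] * A)  # (batch, dim, N); in (0, 1) when A < 0
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]  # (batch, dim, N)
        h = dA * h + dBu  # state update
        ys.append((h * C[:, None, :, t]).sum(-1))  # y_t = <C_t, h_t>, shape (batch, dim)
    y = torch.stack(ys, dim=-1)  # (batch, dim, L)
    if D is not None:
        y = y + u * D[:, None]  # skip connection
    if z is not None:
        y = y * F.silu(z)  # gating
    return y


# Constant input and small step size: the state approaches B*u/|A| regardless of delta.
L = 2000
y = selective_scan_sketch(
    u=torch.full((1, 1, L), 51.0),
    delta=torch.full((1, 1, L), 0.01),
    A=torch.tensor([[-1.0]]),
    B=torch.full((1, 1, L), 19.0),
    C=torch.ones(1, 1, L),
    delta_softplus=False,
)
print(y[0, 0, -1].item())  # roughly 974, close to 19 * 51 / 1 = 969
```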

How are you initializing your weights and cached states? I found that the numbers blew up for me like this when I started from a random initialization of the cached ssm_state and conv_state instead of zeros.

rbitr · Dec 21 '23

Hi, the numbers are also blowing up for me. I tried normalizing the output with LayerNorm, but it didn't help much. I also noticed that ssm_state and conv_state are not actually being used, which was also mentioned in #51. I'd like to know if you have any insights into this. Thanks!

ff98li · Dec 24 '23

I hit the same problem when I replaced the S4 block with Mamba (the S6 block): the output also exploded. In my task, I used the forward function for training, and for inference I called allocate_inference_cache before using the step function.

longdvt · Dec 24 '23

I also got the same NaN problem. I didn't investigate, so I don't know whether it was caused by values blowing up. All I can say is that it happened in a run with a smaller batch size (256) but not with a larger one (1024).

nongiga · Dec 24 '23

Could this be related to the README's troubleshooting section (https://github.com/state-spaces/mamba?tab=readme-ov-file#troubleshooting)?

ftgreat · Dec 25 '23

hi all,

In the end, we found that disabling AMP and reducing the learning rate solves this issue.

JunMa11 · Jan 02 '24
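A plausible reason why disabling AMP helped, assuming the AMP run used float16 (the default autocast dtype on CUDA): float16's largest finite value is 65504, so magnitudes like those reported earlier in this thread overflow to inf when cast, and inf then turns into NaN in ordinary arithmetic (e.g. inf - inf or 0 * inf). bfloat16 keeps float32's exponent range at reduced precision. The snippet checks this for the reported values; `fp16_overflow_report` is a hypothetical helper.

```python
import torch


def fp16_overflow_report(values):
    """Check which float32 values survive a cast to float16 / bfloat16."""
    x = torch.tensor(values, dtype=torch.float32)
    return {
        "fp16_max": torch.finfo(torch.float16).max,
        "fp16_inf": torch.isinf(x.to(torch.float16)).tolist(),
        "bf16_finite": torch.isfinite(x.to(torch.bfloat16)).tolist(),
    }


# Maxima reported earlier: selective-scan out, final projection out, gated output.
report = fp16_overflow_report([15236.45, 188013.33, 1407336.75])
print(report["fp16_max"])  # 65504.0
print(report["fp16_inf"])  # [False, True, True]

# Once an inf appears, ordinary arithmetic can turn it into NaN:
inf = torch.tensor(float("inf"))
print(inf - inf)  # tensor(nan)
```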

@JunMa11, can you show me how to disable AMP?

longdvt · Jan 05 '24

I found that disabling AMP did not seem to help in my case, and lowering the LR just degrades the metrics on this particular dataset. Does anyone have further suggestions?

xtwigs · Jan 12 '24

Quoting @longdvt: "Can you show me how to disable AMP @JunMa11 ?"

During training, simply don't wrap your forward pass and loss computation in a "with autocast" block. Without it, Automatic Mixed Precision is not applied.

You can mitigate NaN issues under AMP with gradient scaling: AMP computes in lower precision (fp16), where small gradient values can underflow to zero, and scaling the loss before the backward pass makes this less likely. The scaler also skips optimizer steps whose gradients contain inf or NaN: https://pytorch.org/docs/stable/amp.html#gradient-scaling

ElliottDyson · Feb 02 '24
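A minimal sketch of the toggle described above, assuming a standard PyTorch training loop (`train_step` and `use_amp` are hypothetical names): `torch.autocast(..., enabled=False)` turns mixed precision off, and `GradScaler(enabled=...)` turns the scaling calls into plain `backward()` / `step()` when AMP is off, so one code path serves both modes.

```python
import torch
import torch.nn as nn


def train_step(model, inputs, target, loss_fn, optimizer, scaler, use_amp: bool):
    """One optimisation step. With use_amp=False, autocast is disabled and the
    step runs in the model's own dtype (normally float32)."""
    device_type = "cuda" if next(model.parameters()).is_cuda else "cpu"
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=use_amp):
        loss = loss_fn(model(inputs), target)
    # With GradScaler(enabled=False) these reduce to loss.backward() and optimizer.step().
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # when enabled, skipped if gradients contain inf/NaN
    scaler.update()
    return loss.detach()


use_amp = False  # set True to re-enable mixed precision
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Newer PyTorch releases also accept torch.amp.GradScaler("cuda", enabled=...).
scaler = torch.cuda.amp.GradScaler(enabled=use_amp and torch.cuda.is_available())
loss = train_step(model, torch.randn(8, 3), torch.randn(8, 1), nn.MSELoss(), optimizer, scaler, use_amp)
print(loss.dtype)  # torch.float32
```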

Quoting @xtwigs: "I found that disabling amp did not seem to help for my use case, plus lowering LR just diminishes metric performance with this specific use case dataset. Does anyone have further suggestions here?"

What optimiser are you using? I've found that at higher learning rates, using RAdam instead of Adam helps immensely. RAdam rectifies the variance of the adaptive learning rate, which damps the step size during the first iterations and acts like a built-in warmup, whereas Adam applies the full learning rate from the start. It also reduces the need for a separate learning-rate warmup.

Hope this helps.

ElliottDyson · Feb 02 '24
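A sketch of the swap suggested above, using PyTorch's built-in `torch.optim.RAdam`; for comparison, the Adam branch adds a linear warmup via `LambdaLR`. `make_optimizer` and the default `warmup_steps` are illustrative assumptions, not values from this thread.

```python
import torch
import torch.nn as nn


def make_optimizer(model: nn.Module, lr: float, use_radam: bool = True, warmup_steps: int = 1000):
    """Return (optimizer, scheduler); call scheduler.step() after each optimizer.step().

    RAdam: constant LR, no explicit warmup.
    Adam: linear warmup over `warmup_steps` steps (an illustrative default).
    """
    if use_radam:
        optimizer = torch.optim.RAdam(model.parameters(), lr=lr)
        schedule = lambda step: 1.0
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        schedule = lambda step: min(1.0, (step + 1) / warmup_steps)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
    return optimizer, scheduler


model = nn.Linear(4, 2)
optimizer, scheduler = make_optimizer(model, lr=1e-3)  # RAdam by default
for _ in range(3):
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()
    scheduler.step()
```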

Same issue here when using the bi-Mamba implementation (see https://github.com/hustvl/Vim/blob/main/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py). Two things solve it, but neither is ideal:

  1. fp32 training: more GPU memory and slower training
  2. lower lr: slower convergence and degraded performance

CiaoHe · Mar 03 '24
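A possible middle ground between the two options above (untested in this thread, and only a sketch): keep AMP for the CNN parts but run the Mamba blocks themselves in float32 by disabling autocast locally, a pattern the PyTorch AMP docs describe for precision-sensitive subregions. `FP32Island` is a hypothetical wrapper name. It only protects computation inside the wrapped block; if the block's outputs exceed float16's range, later float16 ops can still overflow.

```python
import torch
import torch.nn as nn


class FP32Island(nn.Module):
    """Hypothetical wrapper: run `inner` in float32 even inside an autocast region."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        device_type = "cuda" if x.is_cuda else "cpu"
        # Locally disable autocast and upcast the input; the inner module's
        # parameters are assumed to be float32 already (the usual AMP setup).
        # The float32 result is returned as-is rather than cast back down.
        with torch.autocast(device_type=device_type, enabled=False):
            return self.inner(x.float())


# Usage sketch: replace each Mamba block `mamba_block` with FP32Island(mamba_block).
```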

I encountered the same problem, and in my case, the training process is still unstable even after reducing the learning rate.

ziyuyuyuyu1 · Mar 19 '24

Hello, how do I train in fp32? When I set autocast(dtype=torch.float32), I encounter:

/mamba_ssm/ops/selective_scan_interface.py", line 301, in backward
    dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
RuntimeError: Expected delta.scalar_type() == input_type to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)

Could you help me? Thanks so much.

EricPaul03 · Mar 27 '24

As far as I am aware, fp32 is the default dtype for tensors, so you shouldn't need autocast at all. Autocast is typically used to automatically run certain operations in lower precision (e.g. fp16) while keeping others at their original precision (e.g. fp32). This may be your issue. Try not using autocast; it may or may not help, but it's definitely worth trying.

ElliottDyson · Mar 27 '24

Thank you very much for your answer. After checking repeatedly, I think I had applied normalization incorrectly. Now my network runs successfully. Thank you again!

EricPaul03 · Mar 29 '24

https://github.com/state-spaces/mamba/issues/72#issuecomment-2027318310

Hi, I'm hitting the same error, "Expected delta.scalar_type() == input_type to be true, but got false." Could you tell me how you fixed it? Thank you very much! @EricPaul03

qiaoqiaoLF · Apr 09 '24

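A hacky workaround for overly large values (which cause NaN outputs):

import torch

# Zero out entries that are NaN, or whose magnitude exceeds 10 (this also catches Inf); threshold chosen ad hoc
is_bad = torch.logical_or(torch.abs(feat) > 10, torch.isnan(feat))
feat = torch.where(is_bad, torch.zeros_like(feat), feat)  # out-of-place, avoiding in-place edits that can break autograd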

shaochengyan · May 08 '24

#72 (comment)

Quoting @qiaoqiaoLF: Hi, I encounter the same issue, "Expected delta.scalar_type() == input_type to be true, but got false." Could you tell me how you fixed it? Thank you very much! @EricPaul03

I'm also hitting this issue. Did you find a solution? Thank you very much!

zjr2000 · May 28 '24

#72 (comment)

I just created a new conda environment and reinstalled all the packages (with --force-reinstall, to avoid using the cache). After that, the problem disappeared.

qiaoqiaoLF · May 29 '24
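Why a clean reinstall fixed the dtype error is not established in this thread. One common cause of errors like this is a compiled extension (mamba-ssm, causal-conv1d) built against a different PyTorch/CUDA version than the one currently installed, so a quick, hedged check before reinstalling is to print the installed versions. `installed_versions` is a hypothetical helper; the distribution names are the ones used on PyPI.

```python
from importlib import metadata


def installed_versions(packages=("torch", "mamba-ssm", "causal-conv1d")):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print(installed_versions())
```

If the versions look inconsistent with the PyTorch/CUDA build in use, rebuilding the extensions in a fresh environment, as described above, is a reasonable next step.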

I increased the input dimension and the problem went away.

Jhin3433 · Jun 01 '24