
Add support for left padding and masking in forward() and generate()

Open normster opened this issue 1 year ago • 20 comments

This PR implements masking for left contiguous pad tokens by zeroing out intermediate state values, per the discussion at https://github.com/state-spaces/mamba/issues/66, for all three code paths: non-fused, fused without CUDA graph, and fused with CUDA graph. I'm not sure if this implementation is the best approach, so let me know if there's a better way to do things.
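For illustration, here is a minimal reference recurrence (not the PR's fused kernels; all shapes and names below are assumptions) showing the core idea of zeroing the SSM hidden state at left-padded positions so pads never leak into the state carried forward to real tokens:

```python
# Minimal sketch of masked selective scan; D skip connection and gating omitted.
import torch

def masked_selective_scan_reference(u, delta, A, B, C, mask):
    """u, delta: (batch, dim, seqlen); A: (dim, dstate);
    B, C: (batch, dstate, seqlen); mask: (batch, seqlen), 1 = token, 0 = pad."""
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = torch.zeros(batch, dim, dstate, dtype=u.dtype, device=u.device)
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t].unsqueeze(-1) * A)                 # (batch, dim, dstate)
        dBu = (delta[:, :, t] * u[:, :, t]).unsqueeze(-1) * B[:, :, t].unsqueeze(1)
        x = dA * x + dBu
        x = x * mask[:, t].view(batch, 1, 1)                             # zero state at pads
        ys.append(torch.einsum("bdn,bn->bd", x, C[:, :, t]))
    return torch.stack(ys, dim=-1)                                       # (batch, dim, seqlen)

# Quick shape check: two left-pad positions in row 0.
b, d, n, L = 2, 4, 3, 6
mask = torch.ones(b, L); mask[0, :2] = 0
out = masked_selective_scan_reference(torch.randn(b, d, L), torch.rand(b, d, L),
                                      -torch.rand(d, n), torch.randn(b, n, L),
                                      torch.randn(b, n, L), mask)
print(out.shape)  # torch.Size([2, 4, 6])
```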

I've included a simple testing script at tests/test_padding.py which can be run with python tests/test_padding.py to compare prefill logits + generation outputs with and without left padding.
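As context, a hedged sketch of the kind of check such a script performs (the actual tests/test_padding.py may differ; the model here is assumed to return an object with a .logits field and to accept the mask keyword that this PR adds):

```python
# Compare prefill logits with and without left padding; with correct masking
# the two should agree up to numerical noise.
import torch

def compare_prefill_logits(model, input_ids, pad_token_id, n_pad=8):
    logits_ref = model(input_ids).logits                       # (1, L, vocab)

    # Left-pad the same prompt and build a 0/1 mask over real tokens.
    pad = torch.full((1, n_pad), pad_token_id,
                     dtype=input_ids.dtype, device=input_ids.device)
    padded_ids = torch.cat([pad, input_ids], dim=1)
    mask = torch.cat([torch.zeros_like(pad), torch.ones_like(input_ids)], dim=1)

    # Drop the pad positions before comparing against the unpadded run.
    logits_pad = model(padded_ids, mask=mask).logits[:, n_pad:]
    return (logits_ref - logits_pad).abs().max()
```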

I also evaluated the models with/without batching + left-padding + masking on a question answering dataset and found nearly identical accuracies. Batching + left-padding + no masking hurts accuracy by a couple percentage points.

normster avatar Dec 20 '23 08:12 normster

I tried validating the masking implementation with lm-eval-harness. ~~On HellaSwag, mamba-1.4b with right padding still achieves the reported 59.1% accuracy. Switching to left padding drops this to 55.8% accuracy, and changing lm-eval-harness to 1) construct padding masks and 2) use padding masks does not recover any performance (still 55.8%). I might be using the padding masks incorrectly but it's pretty straightforward so I suspect the issue might lie in my mamba masking change in this PR.~~

Edit: I found an error in my evaluation of lm-eval-harness with left padding. After fixing it, I get 59.1% with left padding + masking, the same as with right padding. But left padding + no masking also gives 59.1%, since lm-eval-harness collates prompts by length, which minimizes the number of padding tokens. I found no difference in pythia-1.4b performance with left padding, with or without masking (52.1%, as reported with right padding). Switching to a fixed, random collate function exposes a difference in performance on pythia-1.4b: 43.4% without masking and 52.1% with masking. But mamba-1.4b is virtually unchanged (59.0%). Maybe it's just more robust to long runs of unmasked padding tokens?

TL;DR: I think my proposed padding + masking works, though it's not clear mamba even really needs the masking.
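For reference, this is how padding masks for batched, left-padded prompts are typically built with a Hugging Face tokenizer (illustrative only, not the exact lm-eval-harness change described above; the GPT-NeoX tokenizer is the one the mamba checkpoints use, but any tokenizer works the same way):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

batch = tokenizer(["a short prompt", "a somewhat longer prompt than the first"],
                  return_tensors="pt", padding=True)
input_ids = batch["input_ids"]       # left-padded token ids
mask = batch["attention_mask"]       # 1 for real tokens, 0 for left pads
# `mask` is what would be passed to the forward() argument added by this PR.
```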

normster avatar Dec 20 '23 09:12 normster

I just want to express my interest in left padding and masking. Thanks for the effort.

pjsample avatar Dec 21 '23 22:12 pjsample

Curious how one would mask in order to train on outputs only, if you think masking isn't needed.

thistleknot avatar Dec 22 '23 02:12 thistleknot

Curious how one would mask in order to train on outputs only, if you think masking isn't needed.

My understanding is that the output of the model (token logits) is causal by default, so no masking is needed when the model is trained autoregressively. For an idea of how to train the model, look here: https://github.com/havenhq/mamba-chat/blob/main/train_mamba.py
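On the "train on outputs only" question, one standard approach (independent of this PR and of any mask inside the model) is to mask the loss rather than the inputs: mark prompt tokens with -100 so cross-entropy ignores them. A sketch with illustrative names:

```python
import torch
import torch.nn.functional as F

def completion_only_loss(logits, input_ids, prompt_lengths):
    """logits: (B, L, V); input_ids: (B, L); prompt_lengths: (B,) prompt tokens to ignore."""
    labels = input_ids.clone()
    # Mask out the prompt portion of every sequence.
    for i, plen in enumerate(prompt_lengths):
        labels[i, :plen] = -100
    # Standard causal LM shift: predict token t+1 from position t.
    shift_logits = logits[:, :-1].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                           shift_labels.view(-1), ignore_index=-100)
```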

pjsample avatar Dec 22 '23 18:12 pjsample

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

sentialx avatar Dec 24 '23 11:12 sentialx

I think that if you want padding + masking to be effective, you need to do pre-training in which chunks are not always filled with full sentences.

junphine avatar Jan 02 '24 03:01 junphine

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

sunningmbzuai avatar Jan 08 '24 16:01 sunningmbzuai

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because PyTorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.

On the other hand, to bypass the recalculation of conv1d in the backward pass, we can set checkpoint_lvl to 0 (?).

Can someone verify if my thought process is accurate?
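For anyone hitting the "15 vs 16 gradients" error, here is a minimal standalone illustration of the rule involved (separate from the mamba code): backward() must return exactly one value per forward() input, with None for non-differentiable inputs such as a mask.

```python
import torch

class MaskedScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mask):            # two inputs ...
        ctx.save_for_backward(mask)
        return x * mask

    @staticmethod
    def backward(ctx, grad_out):          # ... so two outputs are required
        (mask,) = ctx.saved_tensors
        return grad_out * mask, None      # None = no gradient for the mask

x = torch.randn(4, requires_grad=True)
mask = torch.tensor([1., 1., 0., 0.])
MaskedScale.apply(x, mask).sum().backward()
print(x.grad)   # gradient is zeroed where the mask is zero
```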

xtwigs avatar Jan 11 '24 23:01 xtwigs

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.

On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).

Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301? That may help a lot.

sunningmbzuai avatar Jan 18 '24 17:01 sunningmbzuai

@normster @tridao @albertfgu I believe this feature would be very nice to have in a stable release. Can we work towards merging this into main and have it in the next stable release? I am happy to help in any way.

abdulfatir avatar Feb 18 '24 10:02 abdulfatir

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also ran into this error. Could you please provide some insight into how to fix this issue in the backward pass?

zigzagcai avatar Feb 29 '24 07:02 zigzagcai

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue.

On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?).

Can someone verify if my thought process is accurate?

I have tested your idea; it does not seem to work and causes a CUDA OOM. Error message:

  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 310, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 97, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 221, in forward
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 79.33 GiB total capacity; 74.83 GiB already allocated; 313.81 MiB free; 77.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My code changes:

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 35143ad..6792749 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function):
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1):
+                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=0):
         """
              xz: (batch, dim, seqlen)
         """
@@ -298,7 +298,7 @@ class MambaInnerFn(torch.autograd.Function):
                 dout_proj_weight, dout_proj_bias,
                 dA, dB, dC, dD,
                 ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
+                dB_proj_bias, dC_proj_bias, None, None)

zigzagcai avatar Feb 29 '24 08:02 zigzagcai

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate?

I have verified your idea and it seems not to work and causes CUDA OOM. Err Msg:

  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/hwfile/caizheng/devel/mamba/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/mnt/hwfile/caizheng/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 310, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 97, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/mnt/hwfile/caizheng/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 221, in forward
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 79.33 GiB total capacity; 74.83 GiB already allocated; 313.81 MiB free; 77.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My code changes:

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 35143ad..6792749 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function):
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1):
+                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=0):
         """
              xz: (batch, dim, seqlen)
         """
@@ -298,7 +298,7 @@ class MambaInnerFn(torch.autograd.Function):
                 dout_proj_weight, dout_proj_bias,
                 dA, dB, dC, dD,
                 ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
+                dB_proj_bias, dC_proj_bias, None, None)

Is it using the default configuration defined in MambaConfig? The default values specify a huge network with 64 layers and an embedding size of 2560. Also, checkpoint_lvl=0 disables checkpointing and asks the forward pass to keep the convolution and delta results in GPU memory.
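To make the trade-off concrete, here is a toy sketch of the checkpoint_lvl mechanism (simplified, not the actual MambaInnerFn code): level 0 keeps the intermediate activation in GPU memory for the backward pass, level 1 drops it and recomputes it, trading compute for memory.

```python
import torch

class CheckpointedTanhMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, checkpoint_lvl):
        h = x @ weight
        out = torch.tanh(h)
        if checkpoint_lvl >= 1:
            h = None                      # drop the activation, recompute in backward
        ctx.save_for_backward(x, weight, h)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, h = ctx.saved_tensors
        if h is None:                     # lvl >= 1: trade compute for memory
            h = x @ weight
        t = torch.tanh(h)
        grad_h = grad_out * (1 - t * t)
        # One None per non-differentiable forward input (checkpoint_lvl).
        return grad_h @ weight.t(), x.t() @ grad_h, None

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 32, requires_grad=True)
CheckpointedTanhMM.apply(x, w, 1).sum().backward()   # lvl 1: less memory, more compute
```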

enneamer avatar Feb 29 '24 08:02 enneamer

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate?

I have verified your idea and it seems not to work and causes CUDA OOM. Err Msg:

  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/some_path/zizgagcai/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 310, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/some_path/miniconda3/envs/final_dev/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 97, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/some_path/zigzagcai/devel/mamba/mamba_ssm/ops/selective_scan_interface.py", line 221, in forward
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 79.33 GiB total capacity; 74.83 GiB already allocated; 313.81 MiB free; 77.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My code changes:

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 35143ad..6792749 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -159,7 +159,7 @@ class MambaInnerFn(torch.autograd.Function):
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1):
+                C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=0):
         """
              xz: (batch, dim, seqlen)
         """
@@ -298,7 +298,7 @@ class MambaInnerFn(torch.autograd.Function):
                 dout_proj_weight, dout_proj_bias,
                 dA, dB, dC, dD,
                 ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None)
+                dB_proj_bias, dC_proj_bias, None, None)

Is it using the default configuration defined in MambaConfig? The default values specify a huge network with 64 layers and embedding size of 2560. And checkpoint_lvl=0 disables checkpoint and then asks the forward pass to keep convolution and delta results in the GPU memory.

No. I am evaluating the 1.4B-parameter mamba model, with 48 layers and a model dimension of 2048, matching the sizes on the repo page. As we know from the mamba paper, setting checkpoint_lvl=0 disables the recomputation of conv1d_out and delta in the backward pass and stores those values in GPU memory, which leads to much higher memory usage.

Here are my experiment details:

  1. original 1.4B mamba model: works well on 1 node with 8x A100 GPUs
  2. 1.4B mamba model patched with this PR: encounters OOM even on 4 nodes with 32x A100 GPUs

zigzagcai avatar Feb 29 '24 08:02 zigzagcai

We are trying this PR because we want mamba to process packed sequences, as is done in transformer-based models. If we simply pad the sequences with zeros, a lot of computation is wasted on meaningless padding tokens. We just want to use a mask to mark the padding tokens so that the computation focuses on the real tokens.
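For illustration, the packed-sequence bookkeeping being asked for looks like the sketch below (the mamba kernels discussed in this thread do not expose such an interface; this only shows the data layout): variable-length sequences are concatenated into one row and cumulative offsets mark the boundaries, so no compute is spent on pad tokens.

```python
import torch

def pack_sequences(seqs, pad_token_id=0):
    lengths = torch.tensor([len(s) for s in seqs])
    packed = torch.cat(seqs)                                   # (total_tokens,)
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long),
                            lengths.cumsum(0)])                # (num_seqs + 1,)

    # Equivalent padded layout, for comparison: (num_seqs, max_len) plus a mask.
    max_len = int(lengths.max())
    padded = torch.full((len(seqs), max_len), pad_token_id, dtype=packed.dtype)
    mask = torch.zeros(len(seqs), max_len, dtype=torch.bool)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = s
        mask[i, :len(s)] = True
    return packed, cu_seqlens, padded, mask

packed, cu_seqlens, padded, mask = pack_sequences(
    [torch.tensor([5, 6, 7]), torch.tensor([8, 9])])
print(cu_seqlens)   # tensor([0, 3, 5])
```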

zigzagcai avatar Feb 29 '24 09:02 zigzagcai

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

xtwigs avatar Feb 29 '24 12:02 xtwigs

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

Hello, thanks for sharing! :D I might be wrong, but when I tried this branch I found that the OOM still appeared in the backward pass. Meanwhile, python tests/test_padding.py reports a huge max L2 error (padded): tensor(3442.7573) in the forward pass, which did not appear in Norman's original PR and might indicate that the left padding and masking do not work in the forward pass.

Is there a reproducible test snippet showing that left padding + masking works in both the forward and backward passes? Thanks!

zigzagcai avatar Mar 04 '24 03:03 zigzagcai

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

Hello, thanks for the sharing! :D I might be wrong. But when I tried this branch and found that OOM still appeared in the backward pass, meanwhile python tests/test_padding.py leads to huge max L2 errors (padded): tensor(3442.7573) in the forward pass, which did not appear in Norman’s original PR and might indicated the left padding and masking not work in the forward pass.

Is there any reproducible test code snippet indicating left padding+ masking works in both fwd and bwd pass? Thanks!

Can you try this while disabling the dropout module added in the mamba_simple code? (The default was set to 0.1.)

xtwigs avatar Mar 04 '24 15:03 xtwigs

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

Hello, thanks for the sharing! :D I might be wrong. But when I tried this branch and found that OOM still appeared in the backward pass, meanwhile python tests/test_padding.py leads to huge max L2 errors (padded): tensor(3442.7573) in the forward pass, which did not appear in Norman’s original PR and might indicated the left padding and masking not work in the forward pass. Is there any reproducible test code snippet indicating left padding+ masking works in both fwd and bwd pass? Thanks!

Can you try this while disabling the dropout module added in the mamba simple code? (default was set to 0.1)

Hi xtwigs, many thanks for your reply! I tried the latest code from your branch and found the mamba block runnable without errors, but when I run the test code the max L2 errors are still relatively large.

zigzagcai avatar Mar 11 '24 10:03 zigzagcai

I tried using this branch but got an error about not getting expected number of gradients during backward (15 vs 16)

Yeah, I got the same error. Does anyone know how to solve it?

I think this error happens because Pytorch expects a gradient for the mask to be returned, which isn't the case here. Adding a None in mamba_ssm/ops/selective_scan_interface.py#301 fixes this issue. On the other hand to bypass the recalculation of conv1d in the backward pass we can set checkpoint_level to 0 (?). Can someone verify if my thought process is accurate?

Could you provide your change in mamba_ssm/ops/selective_scan_interface.py#301 ? That may help a lot.

Hi, I also meet this error. Could you please provide some insight about how to fix this issue on the backward pass?

Hi! Sorry, but I haven't run into the OOM issue. I keep a fork of this code here in case it might help. (I also run this with checkpoint_lvl=0)

Hello, thanks for the sharing! :D I might be wrong. But when I tried this branch and found that OOM still appeared in the backward pass, meanwhile python tests/test_padding.py leads to huge max L2 errors (padded): tensor(3442.7573) in the forward pass, which did not appear in Norman’s original PR and might indicated the left padding and masking not work in the forward pass. Is there any reproducible test code snippet indicating left padding+ masking works in both fwd and bwd pass? Thanks!

Can you try this while disabling the dropout module added in the mamba simple code? (default was set to 0.1)

Hi xtwigs, many thanks for your reply! I tried the latest code of your branch and found mamba block runnable without error, but when I run the test code, the max L2 errors are still relatively large.

Hi zigzagcai, have you solved the problem of masking in the mamba block?

laulampaul avatar Apr 01 '24 04:04 laulampaul