
Function 'MambaSplitConv1dScanCombinedFnBackward' returned nan values in its 0th output.

Open Liamhh20230518 opened this issue 1 year ago • 8 comments

[screenshot attached]

mamba-ssm: 2.0.4

The loss is not NaN, but 'MambaSplitConv1dScanCombinedFnBackward' returned NaN values in its 0th output.
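(For reference, this message is the one raised by PyTorch's autograd anomaly detection. A minimal sketch of enabling it to pinpoint the first backward op that produces a NaN, with model and batch as placeholder names:)

import torch

# Anomaly mode makes autograd raise at the first backward op whose output
# contains NaN, naming the offending Function as in the message above.
torch.autograd.set_detect_anomaly(True)

out = model(batch)         # placeholder model / input
loss = out.float().mean()
loss.backward()            # raises RuntimeError at the NaN-producing backward op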

Thanks

Liamhh20230518 avatar Jul 05 '24 08:07 Liamhh20230518

Can you post a script that helps us reproduce the error? E.g., save the tensors that produce the NaN?
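(One way to do that, as a rough sketch with placeholder names rather than anything provided by mamba-ssm, is to hook the tensor that feeds the fused kernel and dump it the first time a NaN gradient flows through it:)

import torch

def save_on_nan_grad(name, tensor):
    # Register a backward hook that saves the tensor and its gradient to disk
    # as soon as the gradient contains a NaN, so the failure can be replayed.
    def hook(grad):
        if torch.isnan(grad).any():
            torch.save({"value": tensor.detach().cpu(), "grad": grad.detach().cpu()},
                       f"nan_repro_{name}.pt")
    if tensor.requires_grad:
        tensor.register_hook(hook)
    return tensor

# e.g. inside the module's forward, just before the fused call:
# zxbcdt = save_on_nan_grad("zxbcdt", zxbcdt)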

tridao avatar Jul 05 '24 19:07 tridao

@tridao I encountered the same issue in my project. The loss is not NaN, but it seems that the gradient from one Mamba function is NaN. [screenshot attached]

I checked the output of mamba_split_conv1d_scan_combined; the forward result looks fine. [screenshot attached]

Please help here. Thanks!

ArlenCHEN avatar Jul 11 '24 04:07 ArlenCHEN

Can you post a script that helps us reproduce the error?

tridao avatar Jul 11 '24 04:07 tridao

Hi, I've run into the same issue. When I use the Mamba module on its own, it doesn't output any NaN values, but when I integrate it into my model, NaNs start to appear. Interestingly, it works fine on one server, yet on another server the NaN values show up. Both machines are Dockerized, so I don't expect any environmental differences, and I've double-checked things like nvcc and the related configuration.

I’ve been trying to resolve this issue for a while now, but I haven’t been able to find a solution yet. I was wondering if you have come across any potential solutions or workarounds for this issue? I would greatly appreciate any insights or suggestions you might have.

Thank you in advance for your time and help!

I ran the same code on different servers and printed the outputs at the Mamba encoding step. The results were as follows: [screenshots attached]

One outputs all NaN values, while the other outputs all zeros.

Further investigation revealed that the NaN values were already present when the convolutional layer was defined: [screenshots attached]
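(A quick way to confirm this kind of problem, as a sketch that assumes the freshly constructed block is called model, is to scan its parameters right after construction:)

import torch

# Report any parameter that already contains NaNs before the first forward pass
for name, p in model.named_parameters():
    if torch.isnan(p).any():
        print(f"NaN found in parameter: {name}")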

Based on the above findings, I managed to address the NaN values during Mamba encoding. The problem was caused by NaN values appearing during the initialization of the convolutional layer; reinitializing the convolutional layer resolved the NaN issue. However, I am not sure whether this affects the results. I don't think it should have any impact.

Here is a portion of my code.
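(For context, the snippet is the body of a Mamba2-style block; the imports it relies on are roughly the following, with module paths as in mamba-ssm 2.x and the class name assumed, so they may need adjusting:)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
try:
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None

class Mamba2(nn.Module):  # class name assumed; __init__ and forward below form its body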

    def __init__(
            self,
            d_model,
            d_state=128,
            d_conv=4,
            conv_init=None,
            expand=2,
            headdim=64,
            ngroups=1,
            A_init_range=(1, 16),
            dt_min=0.001,
            dt_max=0.1,
            dt_init_floor=1e-4,
            dt_limit=(0.0, float("inf")),
            learnable_init_states=False,
            activation="swish",
            bias=False,
            conv_bias=True,
            # Fused kernel and sharding options
            chunk_size=256,
            use_mem_eff_path=True,
            layer_idx=None,  # Absorb kwarg for general module
            device=None,
            dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
        # self.conv1d.weight._no_weight_decay = True
        # self.conv1d.bias._no_weight_decay = True
        if self.learnable_init_states:
            self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        # Extra normalization layer right before output projection
        assert RMSNormGated is not None
        self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, u, seq_idx=None):
        """
        u: (B, L, D)
        Returns: same shape as u
        """
        batch, seqlen, dim = u.shape

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

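        # Debug: check that none of the parameters fed to the fused kernel already contain NaNs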
        print(torch.isnan(self.conv1d.bias).any())
        print(torch.isnan(self.conv1d.weight).any())
        print(torch.isnan(A).any())
        print(torch.isnan(self.D).any())
        print(torch.isnan(self.norm.weight).any())
        print(torch.isnan(self.out_proj.weight).any())
        if self.use_mem_eff_path:
            # Fully fused path
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.eps,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=False,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            # print(out)
        else:
            z, xBC, dt = torch.split(
                zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
            )
            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)
            assert self.activation in ["silu", "swish"]

            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
                )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                xBC = xBC[:, :seqlen, :]
            else:
                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)

            # Split into 3 main branches: X, B, C
            # These correspond to V, K, Q respectively in the SSM/attention duality
            x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                seq_idx=seq_idx,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            y = rearrange(y, "b l h p -> b l (h p)")

            # Multiply "gate" branch and apply extra normalization layer
            y = self.norm(y, z)
            out = self.out_proj(y)
        return out
    def initialize_weights(self):
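        # Re-initialize the Conv1d weights (and zero the bias) to clear NaNs that appeared at construction time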
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
                    
if __name__ == '__main__':
    from transformers import BertConfig
    import utils
    import os
    import torch
    params = utils.Params()
    # Prepare model
    bert_config = BertConfig.from_json_file(os.path.join(params.bert_model_dir, 'config.json'))
    model = BertForRE.from_pretrained(config=bert_config,
                                      pretrained_model_name_or_path=params.bert_model_dir,
                                      params=params)

    model.to(params.device)
    """
            input_ids: (batch_size, seq_len)
            attention_mask: (batch_size, seq_len)
            rel_tags: (bs, rel_num)
            potential_rels: (bs,), only in train stage.
            seq_tags: (bs, 2, seq_len)
            corres_tags: (bs, seq_len, seq_len)
            ex_params: experiment parameters
    """
    print(model)
    model.encoder.mamba.initialize_weights()

liguochun0304 avatar Jan 13 '25 05:01 liguochun0304

Same problem, can't find any fix.

TheMakerOfWorlds avatar Jan 15 '25 05:01 TheMakerOfWorlds

I also have the same problem

Sheng-T avatar Mar 19 '25 11:03 Sheng-T

I also have the same problem.

https://blog.csdn.net/m0_57145438/article/details/145114872

liguochun0304 avatar Mar 21 '25 03:03 liguochun0304

Same problem, can't find any fix.

https://blog.csdn.net/m0_57145438/article/details/145114872

liguochun0304 avatar Mar 21 '25 03:03 liguochun0304