
Mamba2 Causality

Open · William-HYWu opened this issue 9 months ago • 6 comments

Hi. Thank you for your wonderful work! I would like to inquire about the causality of Mamba2. I think it should theoretically be causal; however, when I run the code below:

import torch
from mamba_ssm import Mamba2

torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=768,  # Model dimension d_model
    d_state=256,  # SSM state expansion factor
    rmsnorm=True,
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
)

model = model.cuda()
model.eval()
inputs = torch.randn(1, 128, 768).to(torch.device('cuda'))
outputs1 = model(inputs[:, :10, :])    # forward pass on the first 10 tokens only
outputs1 = outputs1.squeeze()
outputs2 = model(inputs)               # forward pass on the full 128-token sequence
outputs2 = outputs2.squeeze()[:10, :]  # keep only the first 10 positions
print(outputs1.shape)
print(outputs2.shape)
assert torch.equal(outputs1, outputs2), "Outputs are not equal"

I get AssertionError: Outputs are not equal
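The size of the mismatch can be inspected with, for example:

# How far apart are the two outputs at their worst element?
print((outputs1 - outputs2).abs().max().item())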

I have already ruled out randomness as a factor, since when I run the following:

import torch
from mamba_ssm import Mamba2

torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=768,  # Model dimension d_model
    d_state=256,  # SSM state expansion factor
    rmsnorm=True,
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
)

model = model.cuda()
model.eval()
inputs = torch.randn(1, 128, 768).to(torch.device('cuda'))
outputs1 = model(inputs[:, :10, :])  # same 10-token input...
outputs1 = outputs1.squeeze()
outputs2 = model(inputs[:, :10, :])  # ...run twice to check determinism
outputs2 = outputs2.squeeze()
print(outputs1.shape)
print(outputs2.shape)
assert torch.equal(outputs1, outputs2), "Outputs are not equal"

The assertion passed. Is there any extra argument I need to add to make it causal? Thank you for your help.

William-HYWu · Mar 05 '25 04:03

How large is the difference?

tridao · Mar 05 '25 06:03

> How large is the difference?

About 1e-7 to 1e-6. I suspect it is due to floating-point precision rather than the model itself.

William-HYWu · Mar 05 '25 08:03

Why would these be the same? The hidden states should be different after processing 10 items in the sequence. It is not a linear time-invariant system.

peterbjorgensen · Mar 05 '25 09:03

> Why would these be the same? The hidden states should be different after processing 10 items in the sequence. It is not a linear time-invariant system.

Yes, but I'm comparing the model's output rather than its hidden state. Since the model is causal, the first 10 outputs should remain the same regardless of sequence length, as no future information is used.

That said, the difference I observed is extremely small, which makes me inclined to believe the assertion failed because of floating-point precision effects on the GPU rather than a fundamental issue with the model.
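For what it's worth, a tolerance-based comparison sidesteps the exact-equality check; a minimal sketch, where atol=1e-5 is an assumed bound chosen above the 1e-7 to 1e-6 differences I observed:

# Compare within a floating-point tolerance instead of bitwise equality.
# atol=1e-5 is an assumption, picked above the observed 1e-7 to 1e-6 gap.
assert torch.allclose(outputs1, outputs2, atol=1e-5), "Outputs differ beyond tolerance"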

William-HYWu · Mar 05 '25 10:03

Yes, sorry, I misread your code. That's interesting. Have you tried setting the chunk_size parameter to 1? Mamba splits the input into chunks, processes them in parallel, and then recombines them, so there might be some numerical noise depending on the chunking.
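Something along these lines (a sketch; chunk_size is a Mamba2 constructor argument, but whether the fused kernels accept a value of 1 is an assumption on my part):

import torch
from mamba_ssm import Mamba2

torch.manual_seed(0)

# Same causality test as above, but with chunk_size=1 so the scan is not
# recombined across chunks. Whether the kernels accept chunk_size=1 is an
# assumption; if they reject it, try a smaller power of two instead.
model = Mamba2(
    d_model=768,
    d_state=256,
    d_conv=4,
    expand=2,
    rmsnorm=True,
    chunk_size=1,
).cuda().eval()

inputs = torch.randn(1, 128, 768, device='cuda')
out_short = model(inputs[:, :10, :]).squeeze()
out_full = model(inputs).squeeze()[:10, :]

# If chunking is the source of the noise, this difference should shrink.
print((out_short - out_full).abs().max().item())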

peterbjorgensen · Mar 05 '25 14:03

> > How large is the difference?
>
> About 1e-7 to 1e-6. I suspect it is due to floating-point precision rather than the model itself.

That's probably fine.

tridao · Mar 05 '25 17:03