Mamba2 Causality
Hi. Thank you for your wonderful work! I would like to ask about the causality of Mamba2. Theoretically it should be causal; however, when I run the code below:
import torch
from mamba_ssm import Mamba2
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=768,  # Model dimension d_model
    d_state=256,  # SSM state expansion factor
    rmsnorm=True,
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
)
model = model.cuda()
model.eval()
inputs = torch.randn(1, 128, 768).to(torch.device('cuda'))
outputs1 = model(inputs[:,:10,:])
outputs1 = outputs1.squeeze()
outputs2 = model(inputs)
outputs2 = outputs2.squeeze()[:10,:]
print(outputs1.shape)
print(outputs2.shape)
assert torch.equal(outputs1, outputs2), "Outputs are not equal"
I get AssertionError: Outputs are not equal.
I have already ruled out randomness as the cause, because when I run the model twice on the same 10-token input:
import torch
from mamba_ssm import Mamba2
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=768,  # Model dimension d_model
    d_state=256,  # SSM state expansion factor
    rmsnorm=True,
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
)
model = model.cuda()
model.eval()
inputs = torch.randn(1, 128, 768).to(torch.device('cuda'))
outputs1 = model(inputs[:,:10,:])
outputs1 = outputs1.squeeze()
outputs2 = model(inputs[:,:10,:])
outputs2 = outputs2.squeeze()
print(outputs1.shape)
print(outputs2.shape)
assert torch.equal(outputs1, outputs2), "Outputs are not equal"
The assertion passed. Is there any extra argument I need to add to make it causal? Thank you for your help.
How large is the difference?
About 1e-7 to 1e-6. I suspect it is due to floating-point precision rather than the model itself.
Why would these be the same? The hidden states should be different after processing 10 items in the sequence. It is not a linear time-invariant system.
Yes, but I'm comparing the model's output rather than its hidden state. Since the model is causal, the first 10 outputs should remain the same regardless of sequence length, as no future information is used.
That said, the difference I observed is extremely small, which makes me inclined to believe the assertion failed due to inherent GPU precision variations rather than a fundamental issue with the model.
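The property being tested can be illustrated without mamba_ssm. The sketch below uses a hypothetical scalar recurrence whose decay depends on the input. Like a selective SSM, it is therefore not time-invariant. It is still causal, because output t reads only inputs 0..t. The function `toy_selective_scan`, the sigmoid gate and the constant `b` are illustrative assumptions, not Mamba2's actual parameterization.

```python
import math
import random

def toy_selective_scan(xs):
    """Hypothetical causal recurrence with an input-dependent decay.

    h_t = a(x_t) * h_{t-1} + b * x_t,   y_t = h_t
    a(x_t) changes with the input, so the system is not time-invariant,
    but y_t depends only on x_0..x_t (causal).
    """
    b = 0.5
    h = 0.0
    ys = []
    for x in xs:
        a = 1.0 / (1.0 + math.exp(-x))  # sigmoid gate in (0, 1), depends on x_t
        h = a * h + b * x
        ys.append(h)
    return ys

rng = random.Random(0)
inputs = [rng.gauss(0.0, 1.0) for _ in range(128)]

outputs1 = toy_selective_scan(inputs[:10])  # only the first 10 steps
outputs2 = toy_selective_scan(inputs)[:10]  # full sequence, then truncate

# Same operations in the same order, so the prefixes match bit for bit.
print(outputs1 == outputs2)  # True
```

Here the equality is exact because both runs perform identical floating-point operations in the same order. A GPU kernel that organizes the work differently for a 10-token input and a 128-token input gives no such guarantee, even when the model is causal.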
Yes, sorry, I misread your code. That's interesting. Have you tried setting the chunk_size parameter to 1? Mamba2 splits the input into chunks, processes them in parallel, and then recombines them, so there might be some numerical noise depending on the chunking.
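A rough sketch of how chunking can change the rounding, under simplifying assumptions. A hypothetical scalar recurrence h_t = a_t * h_{t-1} + x_t is evaluated twice: once step by step, and once chunk by chunk. In the chunked version, each chunk builds its outputs from cumulative decay products plus the state carried in from the previous chunk. This only mirrors the general idea of a chunked scan, not Mamba2's actual SSD kernel. `sequential_scan`, `chunked_scan` and their `chunk_size` argument are illustrative names.

```python
import random

def sequential_scan(a, x):
    """Reference: h_t = a_t * h_{t-1} + x_t, one step at a time."""
    h, out = 0.0, []
    for a_t, x_t in zip(a, x):
        h = a_t * h + x_t
        out.append(h)
    return out

def chunked_scan(a, x, chunk_size):
    """Same recurrence, evaluated chunk by chunk.

    Inside a chunk, each h_t is built directly from decay products
    (a 'parallel' form); only the chunk's last state is carried into
    the next chunk. Mathematically equal to sequential_scan, but the
    multiplications and additions happen in a different order.
    """
    h_prev, out = 0.0, []
    for start in range(0, len(x), chunk_size):
        end = min(start + chunk_size, len(x))
        for t in range(start, end):
            # Contribution of the state carried in from earlier chunks.
            decay = 1.0
            for r in range(start, t + 1):
                decay *= a[r]
            h_t = decay * h_prev
            # Contributions of the inputs inside this chunk.
            for s in range(start, t + 1):
                w = 1.0
                for r in range(s + 1, t + 1):
                    w *= a[r]
                h_t += w * x[s]
            out.append(h_t)
        h_prev = out[-1]  # state handed to the next chunk
    return out

rng = random.Random(0)
n = 128
a = [rng.uniform(0.8, 1.0) for _ in range(n)]
x = [rng.gauss(0.0, 1.0) for _ in range(n)]

ref = sequential_scan(a, x)
for chunk_size in (1, 16, 64):
    out = chunked_scan(a, x, chunk_size)
    diff = max(abs(p - q) for p, q in zip(ref, out))
    print(chunk_size, diff)  # rounding-level differences, possibly exactly 0
```

In this toy, chunk_size=1 performs exactly the sequential operations and so reproduces the reference bit for bit. Larger chunks agree only up to rounding. Whether Mamba2's kernels behave the same way for a given chunk_size is not established by this sketch.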
With a difference of only about 1e-7 to 1e-6, that's probably fine.
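Why a causal model can still fail an exact torch.equal check: floating-point addition is not associative. A kernel that accumulates the same terms in a different order can therefore round differently. Below is a minimal pure-Python illustration, followed by a tolerance-based comparison. `allclose` and `max_abs_diff` are hypothetical helpers. `allclose` uses the same criterion as torch.allclose, |a - b| <= atol + rtol * |b|, with torch.allclose's default tolerances. The sample outputs are made-up values sized like the differences reported in this thread.

```python
# Floating-point addition is not associative: grouping changes the rounding.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left == right)      # False
print(abs(left - right))  # about 1.1e-16

def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Hypothetical helper: elementwise |x - y| <= atol + rtol * |y|,
    the criterion torch.allclose applies (defaults match torch.allclose)."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two sequences."""
    return max(abs(x - y) for x, y in zip(a, b))

# Made-up outputs that differ by roughly 1e-6.
out_short = [0.5, -1.25, 2.0]
out_long = [0.5 + 1e-6, -1.25 - 5e-7, 2.0]

print(out_short == out_long)          # False: exact comparison, like torch.equal
print(allclose(out_short, out_long))  # True: equal within tolerance
```

In the original script, the analogous change would be to replace torch.equal with torch.allclose or torch.testing.assert_close, using explicit rtol and atol values suited to float32. The right tolerance is a judgment call that depends on the output magnitudes.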