mamba

The inference results are incorrect

Status: Open. shahidalihakro opened this issue 11 months ago • 4 comments

The following example prints True for all three comparisons when the sequence length is less than 30, but when the length is above 30, inference produces incorrect results. Does anyone know why this happens?

```python
import torch
from mamba_ssm import Mamba
from mamba_ssm.utils.generation import InferenceParams


@torch.inference_mode()
def run():
    batch, length, dim = 2, 29, 16
    x = torch.randn(batch, length, dim).to("cuda")
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,  # Local convolution width
        expand=2,  # Block expansion factor
        layer_idx=0,  # Key for this layer's cached states in inference_params
    ).to("cuda")

    # Training-style forward pass (full sequence in parallel)
    y1 = model(x)
    assert y1.shape == x.shape

    # Inference-style forward pass (full sequence in parallel)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    y2 = model(x, inference_params=infer_params)

    # Inference-style forward pass (step by step using for loop)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    outs = []
    for i in range(length):
        out = model(x[:, i : i + 1, :], inference_params=infer_params)
        infer_params.seqlen_offset += 1
        outs.append(out)
    y3 = torch.cat(outs, 1)

    print(torch.allclose(y1, y2))  # prints True
    print(torch.allclose(y2, y3))  # prints True
    print(torch.allclose(y1, y3))  # prints True


if __name__ == "__main__":
    run()
```
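To tell floating-point drift apart from a real mismatch, it helps to look at the size of the difference rather than a bare `torch.allclose` result; with tensors, `(y1 - y3).abs().max().item()` gives the largest absolute deviation. The sketch below is a toy illustration only, not Mamba's kernels. It uses a hypothetical scalar recurrence `h_t = a_t * h_{t-1} + b_t` as a stand-in for the SSM state update and evaluates it two ways: once step by step (like the `for` loop above) and once as a parallel prefix scan over the whole sequence (like the full-sequence pass). Both compute the same math in a different order, and every operation can optionally be rounded to float32 with `struct` to mimic single precision. The helper names (`sequential_scan`, `parallel_scan`, `compare`, `f32`) are made up for this example, and `compare` only mirrors the elementwise rule documented for `torch.allclose` (`|x - y| <= atol + rtol * |y|`, defaults `rtol=1e-5`, `atol=1e-8`).

```python
import random
import struct


def f32(x):
    # Round a Python float (float64) to the nearest float32 value.
    return struct.unpack("f", struct.pack("f", x))[0]


def sequential_scan(a, b, rnd=float):
    # Step-by-step evaluation of h_t = a_t * h_{t-1} + b_t with h_{-1} = 0,
    # analogous to feeding one token at a time with a cached state.
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = rnd(rnd(a_t * h) + b_t)
        out.append(h)
    return out


def parallel_scan(a, b, rnd=float):
    # The same recurrence evaluated as a Hillis-Steele inclusive scan over
    # affine maps h -> a * h + b. Same math, different operation order.
    elems = list(zip(a, b))
    n, step = len(elems), 1
    while step < n:
        new = list(elems)
        for i in range(step, n):
            a1, b1 = elems[i - step]  # map for the earlier segment
            a2, b2 = elems[i]  # map for the later segment
            # Compose: apply (a1, b1) first, then (a2, b2).
            new[i] = (rnd(a2 * a1), rnd(rnd(a2 * b1) + b2))
        elems = new
        step *= 2
    # Each element now maps h_{-1} = 0 to h_i, so h_i is its offset term.
    return [b_i for _, b_i in elems]


def compare(x, y, rtol=1e-5, atol=1e-8):
    # Mirrors the elementwise rule of torch.allclose and also returns the
    # largest absolute difference so the size of any drift is visible.
    max_diff = max(abs(p - q) for p, q in zip(x, y))
    close = all(abs(p - q) <= atol + rtol * abs(q) for p, q in zip(x, y))
    return close, max_diff


if __name__ == "__main__":
    rng = random.Random(0)
    for length in (29, 256):
        a = [rng.uniform(0.5, 1.0) for _ in range(length)]
        b = [rng.uniform(-1.0, 1.0) for _ in range(length)]
        for name, rnd in (("float64", float), ("float32", f32)):
            a_r, b_r = [rnd(v) for v in a], [rnd(v) for v in b]
            close, diff = compare(
                sequential_scan(a_r, b_r, rnd), parallel_scan(a_r, b_r, rnd)
            )
            print(f"length={length:4d} {name}: allclose={close} max|diff|={diff:.3e}")
```

Running it prints the largest difference for a short and a long sequence in both precisions. This does not by itself explain the threshold around length 30 reported here; it only shows how to check whether the gap between two evaluation paths is at the level of rounding error or clearly larger.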

shahidalihakro, Feb 02 '25 08:02

I raised this issue a while back in #571. Please let me know if you figure anything out.

karannb, Feb 02 '25 18:02

> I raised this issue a while back in #571. Please let me know if you figure anything out.

Sure, I will let you know if I figure it out. If you find a solution, please let me know too.

shahidalihakro, Feb 03 '25 03:02

For Mamba2 this issue is much worse.

peterbjorgensen, Apr 09 '25 10:04

> I raised this issue a while back in #571. Please let me know if you figure anything out.
>
> Sure, I will let you know if I figure it out. If you find a solution, please let me know too.

Hi, I figured out the solution to my problem, but on closer inspection our problems are quite different. You can find my solution in the same issue (#571), but I don't think it will be of much help here. Sorry, and thanks!

karannb, Apr 15 '25 21:04