The inference results are incorrect
The following example prints all `True` when the sequence length is less than 30, but when it is above 30, the step-by-step inference no longer matches the parallel forward pass. Does anyone know why this happens?
```python
import torch
from mamba_ssm import Mamba
from mamba_ssm.utils.generation import InferenceParams


@torch.inference_mode()
def run():
    batch, length, dim = 2, 29, 16
    x = torch.randn(batch, length, dim).to("cuda")
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=16,   # SSM state expansion factor
        d_conv=4,     # Local convolution width
        expand=2,     # Block expansion factor
        layer_idx=0,
    ).to("cuda")

    # Training-style forward pass (full sequence in parallel)
    y1 = model(x)
    assert y1.shape == x.shape

    # Inference-style forward pass (full sequence in parallel)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    y2 = model(x, inference_params=infer_params)

    # Inference-style forward pass (step by step, using a for loop)
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    outs = []
    for i in range(length):
        out = model(x[:, i : i + 1, :], inference_params=infer_params)
        infer_params.seqlen_offset += 1
        outs.append(out)
    y3 = torch.cat(outs, 1)

    print(torch.allclose(y1, y2))  # prints True
    print(torch.allclose(y2, y3))  # prints True
    print(torch.allclose(y1, y3))  # prints True


if __name__ == '__main__':
    run()
```
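For debugging, it may also help to look at the size of the mismatch rather than a bare `allclose`, since `torch.allclose` defaults to `rtol=1e-5, atol=1e-8` and can flip to `False` on small floating-point drift. Below is a minimal diagnostic sketch I would add (not part of the original repro; `report_divergence` is a hypothetical helper, using only standard torch calls):

```python
def report_divergence(y_parallel, y_stepwise, rtol=1e-3, atol=1e-3):
    # Print how far the step-by-step output drifts from the parallel one,
    # so a real bug can be told apart from accumulating float error.
    diff = (y_parallel - y_stepwise).abs()
    print(f"max abs diff:  {diff.max().item():.6e}")
    print(f"mean abs diff: {diff.mean().item():.6e}")
    # Looser tolerance than torch.allclose's defaults (rtol=1e-5, atol=1e-8)
    print(f"allclose(rtol={rtol}, atol={atol}):",
          torch.allclose(y_parallel, y_stepwise, rtol=rtol, atol=atol))


# e.g. inside run(): report_divergence(y1, y3)
```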
I had raised this issue a while back: #571. Please let me know if you figure out anything.
Sure, I will let you know if I figure it out. If you find a solution, please let me know as well.
For Mamba2 this issue is much worse.
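For reference, here is a minimal sketch of the same repro adapted to Mamba2. This is my own adaptation, not code from the thread: the constructor arguments follow the `mamba_ssm` README, and the `dim`/`headdim` choices are illustrative (in Mamba2, `d_model * expand` must be divisible by `headdim`):

```python
import torch
from mamba_ssm import Mamba2
from mamba_ssm.utils.generation import InferenceParams


@torch.inference_mode()
def run_mamba2():
    batch, length, dim = 2, 64, 256
    x = torch.randn(batch, length, dim).to("cuda")
    model = Mamba2(
        d_model=dim,  # Model dimension d_model
        d_state=64,   # SSM state expansion factor
        d_conv=4,     # Local convolution width
        expand=2,     # Block expansion factor
        headdim=64,   # d_model * expand / headdim must be an integer
        layer_idx=0,  # needed so the inference cache is used
    ).to("cuda")

    # Parallel forward pass over the full sequence
    y1 = model(x)

    # Step-by-step inference, one token at a time
    infer_params = InferenceParams(max_batch_size=batch, max_seqlen=length)
    outs = []
    for i in range(length):
        outs.append(model(x[:, i : i + 1, :], inference_params=infer_params))
        infer_params.seqlen_offset += 1
    y3 = torch.cat(outs, 1)

    # Reportedly a much larger gap here than for the Mamba(1) repro above
    print((y1 - y3).abs().max())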
Hi, I figured out the solution for my problem, but on closer inspection, our problems are quite different. You can check my solution in the same issue, but I don't think it will be of much help to you. Sorry, and thanks!