mamba
The outputs of forward and step are different
I tested the forward and step functions with the same input_ids and the same intermediate states, but the output logits and states are quite different.
The code looks like this:
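# Step mode: feed one token at a time, carrying the states forward.
for i in range(seqlen):
    hidden_state1, conv_state, ssm_state = model.step(input_ids[i], conv_state, ssm_state)
# Forward mode: process the whole sequence in one call.
hidden_state2 = model(input_ids, inference_param)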
I think the two results should be the same, but I do not know why they differ.
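To make the expectation concrete, here is a minimal sketch of the kind of comparison described above. It uses a toy scalar linear recurrence, not Mamba itself: the parameters `A`, `B`, `C` and the functions `step`, `forward`, and `compare` are all hypothetical stand-ins written for illustration. It only shows that a step-by-step pass and a full-sequence pass should agree up to floating-point tolerance when they start from the same state and see the same inputs.

```python
import math

# Toy scalar state-space parameters (hypothetical, not Mamba's weights).
A, B, C = 0.9, 0.5, 2.0

def step(x_t, state):
    """Recurrent mode: consume one input, return (output, new_state)."""
    state = A * state + B * x_t
    return C * state, state

def forward(xs):
    """Full-sequence mode: compute every output from the unrolled sum."""
    return [C * sum(A ** (t - k) * B * xs[k] for k in range(t + 1))
            for t in range(len(xs))]

def compare(xs, rel_tol=1e-9, abs_tol=1e-12):
    """Return the max absolute difference and whether the two modes agree."""
    state = 0.0  # both modes must start from the same (zero) state
    step_outputs = []
    for x_t in xs:
        y_t, state = step(x_t, state)
        step_outputs.append(y_t)
    full_outputs = forward(xs)
    max_diff = max(abs(a - b) for a, b in zip(step_outputs, full_outputs))
    agree = all(math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
                for a, b in zip(step_outputs, full_outputs))
    return max_diff, agree

xs = [1.0, -2.0, 0.5, 3.0]
max_diff, agree = compare(xs)
print(max_diff, agree)
```

In a check like this, a difference near machine precision is ordinary rounding between two evaluation orders. A much larger difference more likely means the two paths did not start from the same state or did not receive the same inputs, so those are worth ruling out first.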