A poor man's Mamba code
I have very bad training loss with a simple Mamba implementation. May I know why?
100%|██████████| 10/10 [03:03<00:00, 18.32s/it]
Epoch: 10, Training Loss: -6515147.2516, Validation Loss: -7471518.3141, Validation Perplexity: 0.0000
Have you plugged in a standard Transformer first? It seems more likely that there's something wrong with the training pipeline than with any particular model.
I plugged in a small BERT model and training works fine, so I am not sure what is missing from my Mamba architecture module.
Please advise.
It looks like you reimplemented the model from scratch, so this is beyond the scope of our ability to help. Perhaps check line by line that your implementation matches ours?
I compared my code line by line with your code as well as @johnma2006 's code, looking across the three files, and found nothing conclusive so far, except for the delta inverse-softplus initialization, which only your code performs.
I am a bit stuck. Please advise.
Hi, here is a suggestion for checking the correctness of your implementation:
- Load an instance of your implementation and the official implementation side-by-side.
- Transfer the official instance's weights into your instance.
- Make sure the forward pass is identical. If not, drill down into each submodule to see where the differences come from (a rough sketch of this check is below).
Good luck!
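A rough sketch of that side-by-side check, assuming both implementations are PyTorch modules with compatible parameter names (`MyMamba` and `load_reference_mamba` are placeholders to adapt to your own code):

```python
import torch

# Hypothetical setup: `MyMamba` is the from-scratch model, `load_reference_mamba`
# stands in for however you instantiate the official model with the same config.
my_model = MyMamba(d_model=768, n_layer=24, vocab_size=50277)
ref_model = load_reference_mamba(d_model=768, n_layer=24, vocab_size=50277)

# Transfer the reference weights; strict=True surfaces naming or shape mismatches,
# which is already a useful signal about where the architectures diverge.
my_model.load_state_dict(ref_model.state_dict(), strict=True)

my_model.eval()
ref_model.eval()

input_ids = torch.randint(0, 50277, (1, 128))  # dummy token ids
with torch.no_grad():
    out_mine = my_model(input_ids)
    out_ref = ref_model(input_ids)

# If this check fails, repeat it per submodule (embedding, each block, final norm, lm head).
print(torch.allclose(out_mine, out_ref, atol=1e-4), (out_mine - out_ref).abs().max().item())
```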
> I compared my code line by line with your code as well as @johnma2006 's code, looking across the three files, and found nothing conclusive so far, except for the delta inverse-softplus initialization, which only your code performs. I am a bit stuck. Please advise.
The delta initialization is important.
Comment on the initialization and parameterization: they are super important, in the sense that without a suitable initialization and parameterization, learning long-term memory with SSMs can be unstable and thus difficult. (https://arxiv.org/abs/2311.14495)
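For reference, the inverse-softplus initialization mentioned above samples the target step size Δ log-uniformly in a small range and stores softplus⁻¹(Δ) in the bias of the Δ projection, so that softplus(bias) lands in the intended range at the start of training. A minimal sketch of that pattern, assuming `dt_proj` is your Δ projection layer and using the commonly cited defaults dt_min=1e-3, dt_max=0.1:

```python
import math
import torch

def init_dt_bias(dt_proj: torch.nn.Linear, dt_min=1e-3, dt_max=0.1, dt_init_floor=1e-4):
    """Initialize the delta-projection bias via inverse softplus."""
    d_inner = dt_proj.bias.shape[0]
    # Sample the target delta log-uniformly in [dt_min, dt_max].
    dt = torch.exp(
        torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
    ).clamp(min=dt_init_floor)
    # Inverse softplus: softplus(inv_dt) == dt.
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    with torch.no_grad():
        dt_proj.bias.copy_(inv_dt)
```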
Thanks for the comments.
I have already incorporated the proper delta initialization into my Mamba code, but it is not helping with the training loss convergence issue yet.
I need to think about this from other angles. :eyes:
@radarFudan : I noticed that StableSSM tries to constrain the growth of the gradients by constraining the eigenvalues. This approach seems to complement what clip_grad_norm() does. I will give StableSSM a go in my implementation and post further updates here, thanks!!
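For what it's worth, gradient-norm clipping only bounds the size of each update step, while an eigenvalue constraint acts on the model itself: if the continuous-time state-matrix eigenvalues stay strictly negative, the discretized recurrence cannot blow up regardless of what the optimizer does. A toy sketch of the diagonal case (an illustration of the idea, not StableSSM's actual parameterization):

```python
import torch
import torch.nn as nn

class StableDiagonalSSM(nn.Module):
    """Diagonal SSM state matrix parameterized as A = -exp(A_log), so its
    eigenvalues are strictly negative and exp(delta * A) has magnitude < 1."""
    def __init__(self, d_state: int):
        super().__init__()
        self.A_log = nn.Parameter(torch.zeros(d_state))

    def A(self) -> torch.Tensor:
        return -torch.exp(self.A_log)  # constrained to (-inf, 0) for any A_log

# Gradient clipping, by contrast, only caps the global norm of each update:
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```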
The stable SSM initializations may or may not help, we've never tried them. But I think the theory doesn't apply directly to the selective SSM setting. I don't think there should be anything particular that you need to do here, so either there's an issue in the implementation or somehow Mamba interacts with your data weirdly, which would be interesting.
- Have you checked that your mamba function returns the same outputs as ours, as @johnma2006 suggested?
- Is there any reason you can't directly call the model from this repository? Is the purpose of your model expository or for research?
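If the goal is not reimplementation, the block from this repository can be dropped in directly, roughly as in the README (the fused kernels require a CUDA device):

```python
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    d_model=dim,  # model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```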
My Mamba implementation now seems to work without any negative training loss. I will do further checking and regression runs to see whether the issue still persists.
Great! What did you change?
@albertfgu : One of the major changes is the output sizing, which has to be vocab_size rather than d_model.
See all the other changes required to get rid of the negative training loss; a sketch of the output-head fix is below.
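To make the fix concrete: the final projection has to map d_model to vocab_size so that cross-entropy is computed over vocabulary logits; with that wiring the loss is always non-negative, so a large negative loss points to a head or loss-function mismatch. A minimal sketch (names and structure are illustrative, not the exact code):

```python
import torch
import torch.nn as nn

class MambaLMHead(nn.Module):
    """Wraps a backbone that returns (batch, seq, d_model) hidden states and
    projects them to vocabulary logits for next-token prediction."""
    def __init__(self, backbone: nn.Module, d_model: int, vocab_size: int):
        super().__init__()
        self.backbone = backbone
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)  # vocab_size, not d_model

    def forward(self, input_ids, labels=None):
        hidden = self.backbone(input_ids)          # (B, L, d_model)
        logits = self.lm_head(self.norm(hidden))   # (B, L, vocab_size)
        loss = None
        if labels is not None:
            # Cross-entropy over the vocabulary is always >= 0.
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.view(-1)
            )
        return logits, loss
```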
I will update the new code on GitHub instead of a gist now.