
A poor man's Mamba code


I have very bad training loss with a simple Mamba code implementation; may I know why?

100%|██████████| 10/10 [03:03<00:00, 18.32s/it]

Epoch: 10, Training Loss: -6515147.2516, Validation Loss: -7471518.3141, Validation Perplexity: 0.0000

buttercutter avatar Dec 13 '23 14:12 buttercutter

Have you plugged in a standard Transformer first? It seems more likely that there's something wrong with the training pipeline than with any particular model.
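
As a generic sanity check (not specific to any particular code): with standard token-level cross-entropy, the loss can never go negative and perplexity = exp(loss) is at least 1, so numbers like the ones above usually point at the loss computation or the output shapes rather than the model. A minimal sketch with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes -- adjust to the real pipeline.
batch, seq_len, vocab_size = 8, 128, 50257
logits = torch.randn(batch, seq_len, vocab_size)           # model output
targets = torch.randint(0, vocab_size, (batch, seq_len))   # token ids

# Cross-entropy over raw (unnormalized) logits is never negative,
# and perplexity = exp(loss) should always be >= 1.
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
ppl = torch.exp(loss)
assert loss.item() >= 0.0
print(f"loss={loss.item():.4f}, perplexity={ppl.item():.2f}")
```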

albertfgu avatar Dec 13 '23 16:12 albertfgu

I had plugged in a small BERT model, and the training works alright, so I am not really sure what else is missing from my Mamba architecture module.

Please advise.

buttercutter avatar Dec 20 '23 15:12 buttercutter

It looks like you reimplemented the model from scratch, so this is beyond the scope of our ability to help. Perhaps check line by line that your implementation matches ours?

albertfgu avatar Dec 20 '23 17:12 albertfgu

I compared my code line by line with your code as well as @johnma2006's code, looking at all three files together, and there are no findings so far, except for the delta inverse-softplus initialization, which only your code performs.

I am a bit stuck. Please advise.

buttercutter avatar Jan 04 '24 08:01 buttercutter

Hi, here is a suggestion to check the correctness of your implementation:

  1. Load an instance of your implementation and the official implementation side-by-side.
  2. Transfer the official instance's weights into your instance.
  3. Make sure the forward pass is identical. If not, drill down into each submodule to see where the diffs are coming from; see the sketch below this list.
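
A rough helper for steps 2 and 3 (purely illustrative; it assumes your parameter names match the official ones, otherwise remap the state-dict keys first):

```python
import torch

@torch.no_grad()
def compare_models(ref_model, my_model, x):
    """Load the reference weights into my_model, then diff the forward outputs.

    Assumes both models use the same parameter names; if they differ,
    remap the state-dict keys before calling load_state_dict.
    """
    my_model.load_state_dict(ref_model.state_dict(), strict=True)
    ref_out = ref_model(x)   # if the model returns a struct, compare the logits field
    my_out = my_model(x)
    diff = (ref_out - my_out).abs().max().item()
    print(f"max abs diff: {diff:.3e}")
    # If the diff is large, repeat the check on individual submodules
    # (embedding, each mixer block, norm, lm head) to localize the mismatch.
    return diff
```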

Good luck!

johnma2006 avatar Jan 05 '24 05:01 johnma2006

> I compared my code line by line with your code as well as @johnma2006's code, looking at all three files together, and there are no findings so far, except for the delta inverse-softplus initialization, which only your code performs.
>
> I am a bit stuck. Please advise.

The delta initialization is important.
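
For reference, the inverse-softplus delta initialization in the official code looks roughly like this (a paraphrased sketch; the sizes here are illustrative):

```python
import math
import torch

# Illustrative sizes; dt_min/dt_max/dt_init_floor follow the usual defaults.
d_inner, dt_min, dt_max, dt_init_floor = 256, 1e-3, 1e-1, 1e-4

# Sample the target time steps log-uniformly in [dt_min, dt_max].
dt = torch.exp(
    torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)

# Invert the softplus so that softplus(bias) reproduces the sampled dt at init:
# inv_softplus(y) = y + log(1 - exp(-y)).
inv_dt = dt + torch.log(-torch.expm1(-dt))

# inv_dt is then copied into the bias of the dt projection,
# e.g. dt_proj.bias.copy_(inv_dt).
```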

tridao avatar Jan 05 '24 05:01 tridao

A comment on the initialization and parameterization: they are very important, in the sense that without a suitable initialization and parameterization, learning long-term memory with SSMs can be unstable and therefore difficult. (https://arxiv.org/abs/2311.14495)

radarFudan avatar Jan 07 '24 12:01 radarFudan

Thanks for the comments.

I had already incorporated the proper delta initialization into the Mamba code, but it is not yet helping with the training loss convergence issue.

I need to think about this from other angles. :eyes:


@radarFudan: I noticed that StableSSM tries to constrain the gradient growth rate by constraining the eigenvalues. This approach seems to complement what clip_grad_norm_() does. I will give StableSSM a go in my implementation and will post further updates here, thanks!
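
As a side note, a minimal sketch of the two mechanisms side by side (the sizes and the stand-in model are hypothetical; the A = -exp(A_log) parameterization is, as far as I can tell, what the official Mamba code already uses to keep the state matrix decaying):

```python
import torch
import torch.nn as nn

# Stability via parameterization: store log(-A) so that A = -exp(A_log)
# always has strictly negative entries, i.e. a decaying recurrence.
d_state = 16
A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
A = -torch.exp(A_log)

# Stability via optimization: cap the global gradient norm each step.
model = nn.Linear(16, 16)                        # stand-in for the full model
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```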

buttercutter avatar Jan 09 '24 16:01 buttercutter

The StableSSM initializations may or may not help; we've never tried them. But I think the theory doesn't apply directly to the selective SSM setting. I don't think there should be anything particular that you need to do here, so either there's an issue in the implementation or Mamba somehow interacts with your data weirdly, which would be interesting.

  1. Have you checked that your mamba function returns the same outputs as ours, as @johnma2006 suggested?
  2. Is there any reason you can't directly call the model from this repository (see the usage sketch below)? Is the purpose of your model expository or for research?
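
For reference, the standalone block from this repository can be used more or less as in the README (a CUDA device is assumed):

```python
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    d_model=dim,   # model dimension
    d_state=16,    # SSM state expansion factor
    d_conv=4,      # local convolution width
    expand=2,      # block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```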

albertfgu avatar Jan 09 '24 18:01 albertfgu

My Mamba code implementation seems to work now without any negative training loss; I will do further checking and run regressions to see if the issue still persists.

buttercutter avatar Jan 17 '24 11:01 buttercutter

Great! What did you change?

albertfgu avatar Jan 17 '24 16:01 albertfgu

@albertfgu: One of the major changes is the output sizing, which has to be vocab_size instead of d_model.
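
Roughly, the fix is just the following (the sizes here are illustrative):

```python
import torch
import torch.nn as nn

d_model, vocab_size = 512, 50257   # illustrative sizes

# The final projection must produce vocabulary logits, not another
# d_model-sized vector; otherwise cross-entropy is computed against the
# wrong dimension and the reported "loss" can go negative.
lm_head = nn.Linear(d_model, vocab_size, bias=False)

hidden_states = torch.randn(2, 64, d_model)   # (batch, seq_len, d_model)
logits = lm_head(hidden_states)               # (batch, seq_len, vocab_size)
```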

See all the other changes required for getting rid of the negative training loss.

I will update the new code on GitHub instead of a gist now.

buttercutter avatar Jan 20 '24 03:01 buttercutter