
Conversion Script for Mamba checkpoints (`mamba_ssm` -> `transformers`)

Open: haileyschoelkopf opened this issue 11 months ago • 10 comments

Feature request

Thanks very much for the Mamba support (#28094), this interoperability is fantastic!

I wanted to ask whether there is any utility (it doesn't have to be clean, just functional) for converting checkpoints made for the mamba_ssm library into the format used by transformers.

This would be very helpful if it exists! Thanks 🤗

Motivation

I'd like to be able to convert newly trained Mamba models from the state-spaces/mamba repo into HF transformers without having to write a conversion script myself, if possible.

Your contribution

I could write a utility for this if none exists but would probably not have the bandwidth to upstream it.

haileyschoelkopf, Mar 13 '24 14:03

Hi @haileyschoelkopf, thanks for opening this feature request!

Normally we have conversion scripts under each model's folder (would be here).

I think adding a conversion script sounds like a good idea! @ArthurZucker contributed the model. He's off at the moment but back soon; I'll let him reply in case there's a good reason we didn't add one alongside the model originally.

amyeroberts, Mar 13 '24 15:03

Thank you! I'll share here if I get the time to implement one myself!

haileyschoelkopf, Mar 13 '24 16:03

If it's okay/not too complicated, could I try and give this a shot (as a new outside contributor)?

Admittedly very new to ML stuff, but at a very high level, would this entail implementing conversion scripts similar to something like what's found in other model dirs such as https://github.com/huggingface/transformers/tree/b340d90738fa14bd6f81b65e4148173cbec62ff6/src/transformers/models/bert ?

I.e., just two files, one for the forward and one for the backward conversion: convert_mamba_ssm_checkpoint_to_pytorch.py and convert_pytorch_to_mamba_ssm_checkpoint.py?

byi8220, Mar 13 '24 17:03

yep! Basically just load checkpoint file -> convert to a format loadable by the other library, e.g. reshaping or renaming weights as needed -> run load_state_dict() using transformers or mamba_ssm model -> call save_pretrained() to output new loadable/uploadable model.
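The renaming step of that pipeline can be sketched on a plain dict. This is a minimal, hypothetical illustration: the helper name `convert_state_dict` and the single prefix rule (`backbone.embedding.` to `backbone.embeddings.`, the mismatch reported further down this thread) are assumptions, and the torch/transformers calls for the other steps appear only as comments.

```python
def convert_state_dict(state_dict, rename_rules):
    """Return a new state dict with key prefixes rewritten per rename_rules.

    rename_rules maps an old key prefix to a new one. Keys matching no rule
    are copied unchanged; values (tensors in practice) are not touched.
    """
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in rename_rules.items():
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break
        converted[new_key] = value
    return converted


# Hypothetical rule: mamba_ssm calls the input embedding "embedding",
# transformers' Mamba backbone calls it "embeddings".
MAMBA_SSM_TO_HF = {"backbone.embedding.": "backbone.embeddings."}

# Surrounding steps in a real script (not executed here):
#   state_dict = torch.load("pytorch_model.bin", map_location="cpu")
#   model = MambaForCausalLM(hf_config)
#   model.load_state_dict(convert_state_dict(state_dict, MAMBA_SSM_TO_HF))
#   model.save_pretrained("converted-mamba")

dummy = {"backbone.embedding.weight": 0, "lm_head.weight": 1}
print(convert_state_dict(dummy, MAMBA_SSM_TO_HF))
# {'backbone.embeddings.weight': 0, 'lm_head.weight': 1}
```

The rule includes the trailing dot so an already-converted `backbone.embeddings.weight` is left alone.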

haileyschoelkopf, Mar 13 '24 17:03

I've made this if it's helpful for anyone. https://gist.github.com/SrGonao/33f373a13a6cad6b245450f3d6361598

SrGonao, Mar 15 '24 19:03

Thanks, that's very helpful. I'll try to get a PR out soon modeled off that.

byi8220, Mar 17 '24 05:03

hey! The reason I did not add one is that the original checkpoints are already compatible. A script can still be added, but only the config needs to be inferred/updated! So your checkpoint should already be compatible: no renaming and no reshaping.

ArthurZucker, Mar 17 '24 23:03

So your checkpoint should already be compatible, no renaming and no reshaping

Hm, when I try to run a conversion, I get an error suggesting there needs to be a rename:

RuntimeError: Error(s) in loading state_dict for MambaForCausalLM:
        Missing key(s) in state_dict: "backbone.embeddings.weight". 
        Unexpected key(s) in state_dict: "backbone.embedding.weight". 

The unexpected key contains "embedding" (with no s at the end), while the missing key contains "embeddings" (with an s at the end)
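A strict state-dict load reports exactly this kind of pair. The sketch below is a simplified model, not PyTorch's implementation: it computes missing and unexpected keys as set differences. The key lists are abridged from the error above and the function name is hypothetical.

```python
def diff_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) the way a strict load reports them.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not expect
    """
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Abridged key sets: mamba_ssm checkpoint vs. transformers MambaForCausalLM.
ckpt = ["backbone.embedding.weight", "backbone.norm_f.weight", "lm_head.weight"]
model = ["backbone.embeddings.weight", "backbone.norm_f.weight", "lm_head.weight"]

missing, unexpected = diff_keys(ckpt, model)
print("Missing:", missing)        # Missing: ['backbone.embeddings.weight']
print("Unexpected:", unexpected)  # Unexpected: ['backbone.embedding.weight']
```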

I've attempted to create a PR which both converts the config and does the above renaming for the forward conversion: https://github.com/huggingface/transformers/pull/29705

(Huge thanks to @SrGonao , my PR does pretty much the same thing as his script, except on local files instead of interacting with the hub)

byi8220, Mar 17 '24 23:03

Oh no 😅 I am not getting this one on the original checkpoints, so maybe it was updated at some point?

ArthurZucker, Mar 21 '24 06:03

Oh no 😅 I am not getting this one on the original checkpoints, so maybe it was updated at some point?

It might have, but I couldn't get it to work without the rename.

A quick printout of the original mamba_ssm model suggests that at least the current version of mamba_ssm names it backbone.embedding.weight (no s):

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
config = MambaConfig(d_model = 64, n_layer = 8)
model = MambaLMHeadModel(config)
print(model)

Outputs:

MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(50280, 64)
    (layers): ModuleList(
      (0-7): 8 x Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=64, out_features=256, bias=False)
          (conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
          (act): SiLU()
          (x_proj): Linear(in_features=128, out_features=36, bias=False)
          (dt_proj): Linear(in_features=4, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=64, bias=False)
        )
        (norm): RMSNorm()
      )
    )
    (norm_f): RMSNorm()
  )
  (lm_head): Linear(in_features=64, out_features=50280, bias=False)
)
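The layer sizes in this printout follow from a few formulas. The sketch below reproduces them for `d_model=64`, assuming mamba_ssm's defaults as I understand them (`d_state=16`, `d_conv=4`, `expand=2`, and an "auto" time-step rank of `ceil(d_model / 16)`) and its padding of the default 50277-token vocabulary up to a multiple of 8. The helper names are hypothetical.

```python
import math


def pad_vocab(vocab_size, multiple=8):
    """Round vocab_size up to the next multiple, as mamba_ssm does for the embedding."""
    remainder = vocab_size % multiple
    return vocab_size if remainder == 0 else vocab_size + multiple - remainder


def mamba_block_shapes(d_model, d_state=16, d_conv=4, expand=2):
    """Return (in_features, out_features) of each projection in one Mamba mixer."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)  # assumed "auto" rank of the time-step projection
    return {
        "in_proj": (d_model, 2 * d_inner),           # output split into x and gate z
        "conv1d_channels": d_inner,                   # depthwise: groups == channels
        "conv1d_padding": d_conv - 1,
        "x_proj": (d_inner, dt_rank + 2 * d_state),  # output split into dt, B and C
        "dt_proj": (dt_rank, d_inner),
        "out_proj": (d_inner, d_model),
    }


print(pad_vocab(50277))  # 50280
print(mamba_block_shapes(64))
```

With `d_model=64` this gives `in_proj` 64 -> 256, `x_proj` 128 -> 36, `dt_proj` 4 -> 128 and a 128-channel conv with padding 3, matching the printout; the padded vocabulary also explains the `Embedding(50280, 64)` row.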

byi8220, Mar 21 '24 15:03

You are right 😢 I have no idea why I did not get any warnings. Maybe, because the weights are tied, it used the lm_head's weights and tied the embeddings to them.

Just saw your PR to remove the tie_weights that were forced before in state-spaces/mamba, so let's try to fix this in transformers as well!

Arf that is really annoying

ArthurZucker, Mar 25 '24 09:03

maybe because the weights are tied it used the lm_head's weights and tied using them.

I'm confused by what you mean here.

I thought that the problem was due to a difference in naming conventions between the two packages, where the transformers library and the mamba_ssm model library just chose to name their embedding layer differently.

Just saw your PR to remove the tie_weights that were forced before in ssm-state

Are you referring to @haileyschoelkopf's PR in https://github.com/state-spaces/mamba/pull/211?

byi8220, Mar 25 '24 10:03

I'll add a loading hook just this once! IMO it should be the cleanest way to fix this.
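A loading hook along these lines might look like the sketch below. It assumes the hook receives the full state dict plus the module's key prefix, as PyTorch load-state-dict pre-hooks do, and renames the legacy key in place. The function name is hypothetical, this is not necessarily the code that ends up in transformers, and the registration call is shown only in a comment.

```python
def rename_legacy_embedding_key(state_dict, prefix, *args):
    """Rename "<prefix>embedding.weight" to "<prefix>embeddings.weight" in place.

    Extra positional args (local_metadata, strict, missing_keys, ...) are
    accepted and ignored so the function fits a load_state_dict pre-hook.
    """
    old_key = prefix + "embedding.weight"
    new_key = prefix + "embeddings.weight"
    if old_key in state_dict and new_key not in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)


# Hypothetical registration in the backbone's __init__ (private PyTorch API):
#   self._register_load_state_dict_pre_hook(rename_legacy_embedding_key)
# When MambaForCausalLM loads, the backbone's prefix would be "backbone.".

sd = {"backbone.embedding.weight": "W", "lm_head.weight": "W"}
rename_legacy_embedding_key(sd, "backbone.")
print(sd)  # {'lm_head.weight': 'W', 'backbone.embeddings.weight': 'W'}
```

Running at load time means old checkpoints keep working without a separate conversion step, while checkpoints already using the new name pass through unchanged.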

ArthurZucker, Mar 25 '24 10:03

I mean that on my side, when implementing Mamba in transformers, I did not get a warning about the weights. I suppose that this is because the weights are tied by default (tie_word_embeddings=True). Thus this:

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280)

does not produce any warning, while:

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280, tie_word_embeddings=False)
Some weights of MambaForCausalLM were not initialized from the model checkpoint at state-spaces/mamba-130m and are newly initialized: ['backbone.embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

does.

and:

>>> from transformers import AutoModel

>>> model = AutoModel.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280)
Some weights of MambaModel were not initialized from the model checkpoint at state-spaces/mamba-130m and are newly initialized: ['backbone.embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
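One way to picture why the tied configuration stays silent is the simplified model below. It is hypothetical, not transformers' actual loading code: a missing parameter that is tied to a parameter present in the checkpoint gets filled by tying and is dropped from the warning list. The key names come from the logs above; the tie pair assumes `lm_head.weight` and `backbone.embeddings.weight` share one tensor when `tie_word_embeddings=True`.

```python
def keys_to_warn_about(checkpoint_keys, model_keys, tied_pairs, tie_word_embeddings):
    """Return the missing keys a loader would warn about.

    tied_pairs lists (a, b) parameter names that share one tensor when tying
    is on. A missing key tied to a key the checkpoint provides gets filled by
    tying, so it is not reported even though its own name never matched.
    """
    checkpoint_keys = set(checkpoint_keys)
    missing = set(model_keys) - checkpoint_keys
    if tie_word_embeddings:
        for a, b in tied_pairs:
            if a in missing and b in checkpoint_keys:
                missing.discard(a)
            elif b in missing and a in checkpoint_keys:
                missing.discard(b)
    return sorted(missing)


ckpt = {"backbone.embedding.weight", "lm_head.weight"}    # mamba_ssm naming
model = {"backbone.embeddings.weight", "lm_head.weight"}  # transformers naming
ties = [("backbone.embeddings.weight", "lm_head.weight")]

print(keys_to_warn_about(ckpt, model, ties, tie_word_embeddings=True))   # []
print(keys_to_warn_about(ckpt, model, ties, tie_word_embeddings=False))  # ['backbone.embeddings.weight']
```

In the `AutoModel` case there is no `lm_head` to tie to, so passing `tied_pairs=[]` reproduces the warning, consistent with the third log above.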

ArthurZucker, Mar 25 '24 10:03

That's a really silent bug, and I got tricked by it...

ArthurZucker, Mar 25 '24 10:03

cc @amyeroberts: core issue!

ArthurZucker, Mar 25 '24 10:03

Hm, just to make sure I understand:

  1. The lack of a warning about uninitialized weights is because, when tie_word_embeddings=True, the input embedding layer's weight name is somehow ignored at some step of loading a pretrained model?
  2. A converter which renames weights alongside config conversion should exist.
  3. This issue may exist in other models since it's due to behavior in weight tying and not something mamba specific?

byi8220, Mar 25 '24 12:03

  1. More like the lack of an error being raised!

  2. "A converter which renames weights alongside config conversion should exist": "could" exist. What I am thinking about is just a simple hook triggered in from_pretrained; let me open a PR in a bit. This avoids having a conversion script.

  3. "This issue may exist in other models since it's due to behavior in weight tying and not something mamba specific?" Totally. But most likely we would have heard about it, as people use the AutoModel API a lot and don't always tie weights.

ArthurZucker, Mar 25 '24 12:03

  1. Just wondering, if it's an error, should these be logged at warning level or are these generally just warnings: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4011-L4030
  2. "This avoids having a conversion script" - Wouldn't there still need to be one to convert the MambaConfig? (If there doesn't need to be a converter at all, I guess I can discard https://github.com/huggingface/transformers/pull/29705)
  3. If it causes issues immediately then yeah that makes sense too.

byi8220, Mar 25 '24 12:03

About 2: I mean avoiding the need to convert explicitly if you know the config. The config can be initialized first with from_pretrained!

But let's still have the conversion script! It will be beneficial to have a mapping between the names and the config explicitly!
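An explicit name mapping for the config could look like the sketch below, working on plain dicts. The mamba_ssm field names (`d_model`, `n_layer`, `vocab_size`, `ssm_cfg`, `residual_in_fp32`, `pad_vocab_size_multiple`, `tie_embeddings`) and the transformers field names (`hidden_size`, `num_hidden_layers`, `state_size`, `conv_kernel`, `expand`, `tie_word_embeddings`) are as I understand the two `MambaConfig` classes; treat the mapping, its defaults and the helper name as assumptions to check against both libraries.

```python
def mamba_ssm_config_to_hf(ssm_config):
    """Map a mamba_ssm MambaConfig (given as a dict) to transformers MambaConfig kwargs."""
    ssm_cfg = ssm_config.get("ssm_cfg") or {}
    vocab_size = ssm_config.get("vocab_size", 50277)
    multiple = ssm_config.get("pad_vocab_size_multiple", 8)
    if vocab_size % multiple != 0:
        # mamba_ssm pads the embedding rows, so checkpoint tensors use the padded size.
        vocab_size += multiple - vocab_size % multiple
    return {
        "hidden_size": ssm_config["d_model"],
        "num_hidden_layers": ssm_config["n_layer"],
        "vocab_size": vocab_size,
        "state_size": ssm_cfg.get("d_state", 16),
        "conv_kernel": ssm_cfg.get("d_conv", 4),
        "expand": ssm_cfg.get("expand", 2),
        "residual_in_fp32": ssm_config.get("residual_in_fp32", True),
        "tie_word_embeddings": ssm_config.get("tie_embeddings", True),
    }


hf_kwargs = mamba_ssm_config_to_hf({"d_model": 768, "n_layer": 24, "vocab_size": 50277})
print(hf_kwargs["vocab_size"], hf_kwargs["num_hidden_layers"])  # 50280 24
# With transformers installed (not executed here):
#   from transformers import MambaConfig
#   config = MambaConfig(**hf_kwargs)
```

The padded vocabulary is consistent with the `vocab_size = 50280` passed to `from_pretrained` in the examples above.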

ArthurZucker, Mar 25 '24 12:03

For 1: yes, a warning should indeed be issued, sorry; we raise an error for mismatched sizes!
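The distinction described here, warnings for missing or unexpected keys but an error for size mismatches, can be modelled as below. This is a simplified, hypothetical sketch (shapes as tuples, function name invented), not transformers' actual code path; the example shapes are illustrative.

```python
import warnings


def check_checkpoint(checkpoint_shapes, model_shapes):
    """Warn about missing/unexpected keys; raise ValueError on any shape mismatch.

    Both arguments map parameter names to shape tuples.
    """
    missing = sorted(set(model_shapes) - set(checkpoint_shapes))
    unexpected = sorted(set(checkpoint_shapes) - set(model_shapes))
    if missing:
        warnings.warn(f"Newly initialized (missing from checkpoint): {missing}")
    if unexpected:
        warnings.warn(f"Not used (unexpected by the model): {unexpected}")
    mismatched = [
        (name, checkpoint_shapes[name], model_shapes[name])
        for name in sorted(set(model_shapes) & set(checkpoint_shapes))
        if checkpoint_shapes[name] != model_shapes[name]
    ]
    if mismatched:
        raise ValueError(f"Size mismatch for: {mismatched}")
    return missing, unexpected


# Illustrative: a padded 50280-row checkpoint vs. a model built with 50277 rows.
try:
    check_checkpoint({"lm_head.weight": (50280, 768)}, {"lm_head.weight": (50277, 768)})
except ValueError as err:
    print(err)
```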

ArthurZucker, Mar 25 '24 12:03

Sounds good, I removed the weight rename from my PR (although now my PR won't actually work until yours is checked in).

byi8220, Mar 25 '24 12:03

Just a quick plug: I think your PR fixes the checkpointing issue, but PR https://github.com/huggingface/transformers/pull/29705 is still open for config->config conversion.

byi8220, Mar 28 '24 14:03

Indeed, reopening!

ArthurZucker, Mar 28 '24 14:03