Conversion Script for Mamba checkpoints (`mamba_ssm` -> `transformers`)
Feature request
Thanks very much for the Mamba support (#28094), this interoperability is fantastic!
I wanted to ask whether there is any utility (doesn't have to be clean, just functional) for converting checkpoints provided for use in the mamba_ssm library into the format used by transformers.
This would be very helpful if it exists! Thanks 🤗
Motivation
I'd like to be able to convert newly trained mamba models from the state-spaces/mamba repo into HF transformers without rewriting a conversion script myself if need be.
Your contribution
I could write a utility for this if none exists but would probably not have the bandwidth to upstream it.
Hi @haileyschoelkopf, thanks for opening this feature request!
Normally we have conversion scripts under each model's folder (would be here).
I think adding a conversion script sounds like a good idea! @ArthurZucker contributed the model. He's off at the moment, but back soon - I'll let him reply in case there's a good reason we didn't add it alongside the model originally.
Thank you! I'll share here if I get the time to implement one myself!
If it's okay/not too complicated, could I try and give this a shot (as a new outside contributor)?
Admittedly very new to ML stuff, but at a very high level, would this entail implementing conversion scripts similar to something like what's found in other model dirs such as https://github.com/huggingface/transformers/tree/b340d90738fa14bd6f81b65e4148173cbec62ff6/src/transformers/models/bert ?
I.e. just 2 files for the forward and backwards pass, convert_mamba_ssm_checkpoint_to_pytorch.py and convert_pytorch_to_mamba_ssm_checkpoint.py?
yep! Basically just load checkpoint file -> convert to a format loadable by the other library, e.g. reshaping or renaming weights as needed -> run load_state_dict() using the transformers or mamba_ssm model -> call save_pretrained() to output a new loadable/uploadable model.
I've made this if it's helpful for anyone. https://gist.github.com/SrGonao/33f373a13a6cad6b245450f3d6361598
Thanks, that's very helpful. I'll try to get a PR out soon modeled off that.
hey! The reason I did not add one is because the original checkpoints are compatible. This can still be added but only the config should be inferred / updated! So your checkpoint should already be compatible, no renaming and no reshaping
> So your checkpoint should already be compatible, no renaming and no reshaping
Hm, when I try to run a conversion, I get an error suggesting there needs to be a rename:
RuntimeError: Error(s) in loading state_dict for MambaForCausalLM:
Missing key(s) in state_dict: "backbone.embeddings.weight".
Unexpected key(s) in state_dict: "backbone.embedding.weight".
The unexpected key contains "embedding" (with no s at the end), while the missing key contains "embeddings" (with an s at the end)
I've attempted to create a PR which both converts the config and does the above renaming for the forward pass: https://github.com/huggingface/transformers/pull/29705
(Huge thanks to @SrGonao , my PR does pretty much the same thing as his script, except on local files instead of interacting with the hub)
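For readers who only need the rename described above, here is a minimal sketch of it on a plain state dict. The function name `rename_mamba_ssm_keys` is hypothetical, and the sketch assumes the embedding key is the only mismatch, which is all the error message in this thread shows; it is not the code from the PR or the gist.

```python
def rename_mamba_ssm_keys(state_dict):
    """Return a copy of `state_dict` with the mamba_ssm embedding key renamed.

    Hypothetical helper: it only covers the single mismatch reported in this
    thread ("backbone.embedding.*" vs "backbone.embeddings.*") and assumes
    every other key already matches.
    """
    old_prefix = "backbone.embedding."
    new_prefix = "backbone.embeddings."
    renamed = {}
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            # Swap the prefix and keep the suffix (e.g. "weight") untouched.
            key = new_prefix + key[len(old_prefix):]
        renamed[key] = value
    return renamed
```

The returned dict could then be passed to load_state_dict() on the transformers model. The input dict is left unmodified, and calling the helper on an already-converted dict is a no-op.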
Oh no 😅 I am not getting this one on the original checkpoints, so maybe it was updated at some point?
> Oh no 😅 I am not getting this one on the original checkpoints, so maybe it was updated at some point?
It might have, but I couldn't get it to work without the rename.
A quick printout of the original ssm model suggests at least the current version of mamba_ssm works with backbone.embedding.weight (no s):
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
config = MambaConfig(d_model = 64, n_layer = 8)
model = MambaLMHeadModel(config)
print(model)
Outputs:
MambaLMHeadModel(
(backbone): MixerModel(
(embedding): Embedding(50280, 64)
(layers): ModuleList(
(0-7): 8 x Block(
(mixer): Mamba(
(in_proj): Linear(in_features=64, out_features=256, bias=False)
(conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
(act): SiLU()
(x_proj): Linear(in_features=128, out_features=36, bias=False)
(dt_proj): Linear(in_features=4, out_features=128, bias=True)
(out_proj): Linear(in_features=128, out_features=64, bias=False)
)
(norm): RMSNorm()
)
)
(norm_f): RMSNorm()
)
(lm_head): Linear(in_features=64, out_features=50280, bias=False)
)
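Instead of reading the module printout by eye, the two naming schemes can also be compared as key sets. The sketch below is a hypothetical helper (`diff_state_dict_keys` is not part of either library) that mirrors the "Missing key(s)" / "Unexpected key(s)" lists in the load_state_dict error; it works on any iterables of key names, so it runs without either library installed.

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Compare checkpoint key names against the names a model expects.

    Hypothetical helper. Returns (missing, unexpected):
      missing    - names the model expects but the checkpoint lacks
      unexpected - names in the checkpoint the model does not know
    """
    checkpoint_keys = set(checkpoint_keys)
    model_keys = set(model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected
```

Assuming both models are instantiated, it could be called as diff_state_dict_keys(ssm_model.state_dict().keys(), hf_model.state_dict().keys()), where ssm_model and hf_model are placeholder names for the mamba_ssm and transformers models.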
You are right 😢 I have no idea why I did not get any warnings / maybe because the weights are tied it used the lm_head's weights and tied using them.
Just saw your PR to remove the tie_weights that were forced before in ssm-state, so let's try to fix this in transformers as well!
Arf that is really annoying
> maybe because the weights are tied it used the lm_head's weights and tied using them.
I'm confused by what you mean here.
I thought that the problem was due to a difference in naming conventions between the two packages, where the transformers library and the mamba_ssm model library just chose to name their embedding layer differently.
> Just saw your PR to remove the tie_weights that were forced before in ssm-state
Are you referring to @haileyschoelkopf's PR in https://github.com/state-spaces/mamba/pull/211?
I'll add a loading hook just this once! IMO should be the cleanest way to fix this
I mean that on my side, when implementing mamba in transformers, I did not have a warning about the weights. I suppose that this is because the weights are tied by default (tie_word_embeddings=True). Thus this:
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280)
does not produce any warning, while:
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280, tie_word_embeddings=False)
Some weights of MambaForCausalLM were not initialized from the model checkpoint at state-spaces/mamba-130m and are newly initialized: ['backbone.embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
does.
and:
>>> from transformers import AutoModel
>>> model = AutoModel.from_pretrained("state-spaces/mamba-130m", num_hidden_layers=24, vocab_size = 50280)
Some weights of MambaModel were not initialized from the model checkpoint at state-spaces/mamba-130m and are newly initialized: ['backbone.embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
That's a really silent bug, and I got tricked by it...
cc @amyeroberts core issue !
Hm, just to make sure I understand:
1. The lack of warning about uninitialized weights is because when tie_word_embeddings=True, the input embedding layer weights name is somehow ignored at some step of loading a pretrained model?
2. A converter which renames weights alongside config conversion should exist.
3. This issue may exist in other models since it's due to behavior in weight tying and not something mamba specific?
1. More like the lack of an error being raised!
2. "A converter which renames weights alongside config conversion should exist." "Could": what I am thinking about is just a simple hook triggered in from_pretrained, let me open a PR in a bit. This avoids having a conversion script.
3. "This issue may exist in other models since it's due to behavior in weight tying and not something mamba specific?" Totally. But most probably we would have heard of that, as people use the AutoModel API a lot, and don't always tie weights.
1. Just wondering, if it's an error, should these be logged at warning level or are these generally just warnings: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L4011-L4030
2. "This avoids having a conversion script" - Wouldn't there still need to be one to convert the MambaConfig? (If there doesn't need to be a converter at all, I guess I can discard https://github.com/huggingface/transformers/pull/29705)
3. If it causes issues immediately then yeah that makes sense too.
About 2. I mean avoid having to explicitly convert if you know the config. Config can be initialized first with from_pretrained!
But let's still have the conversion script! It will be beneficial to have a mapping between the names and the config explicitly!
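As an illustration of what an explicit config mapping might look like, here is a sketch on plain dicts. Only d_model, n_layer, num_hidden_layers and vocab_size appear in this thread; the hidden_size target name, the pad_vocab_size_multiple field, and the padding rule are assumptions about the two config classes, and `convert_ssm_config_dict` is a hypothetical name, not the code from the PR.

```python
def convert_ssm_config_dict(ssm_config):
    """Map a mamba_ssm-style config dict to transformers-style field names.

    Hypothetical sketch. Assumptions not confirmed in this thread:
      - d_model corresponds to hidden_size
      - the checkpoint's vocab size is padded up to a multiple of
        pad_vocab_size_multiple (treated as 1, i.e. no padding, if absent)
    """
    vocab_size = ssm_config["vocab_size"]
    multiple = ssm_config.get("pad_vocab_size_multiple", 1)
    remainder = vocab_size % multiple
    if remainder:
        # Round up so the value matches the padded embedding matrix.
        vocab_size += multiple - remainder
    return {
        "hidden_size": ssm_config["d_model"],
        "num_hidden_layers": ssm_config["n_layer"],
        "vocab_size": vocab_size,
    }
```

The resulting dict could be passed as keyword arguments when building the transformers config; any fields not listed here would keep that library's defaults, which may or may not match a given checkpoint.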
For 1., yes, a warning should indeed be issued, sorry; we raise an error for mismatched sizes!
Sg, I removed the weight rename from my PR (although now my PR won't actually work until yours is checked in)
Just a quick plug: I think your PR fixes the checkpointing issue, but PR https://github.com/huggingface/transformers/pull/29705 is still open for config->config conversion.
Indeed opening again!