Can't train mamba2 from scratch with HF Trainer
I'm trying to train mamba2 130m from scratch.
config = Mamba2Config(
    vocab_size=len(tokenizer.vocab),
    n_positions=10,
    n_embd=768,
    n_layer=12,
    n_head=12,
    n_inner=3072,
)
model = Mamba2ForCausalLM(config)

training_args = TrainingArguments(
    output_dir=args.output_dir,
    logging_dir='./logs',
    gradient_accumulation_steps=1,
    save_steps=2000,
    max_steps=1500000,
    evaluation_strategy="steps",
    eval_steps=2000,
    # prediction_loss_only=True,
    logging_strategy="epoch",
    # logging_steps=500,
    learning_rate=1e-4,
    # save_total_limit=999999,
    fp16=True,
    dataloader_num_workers=4,
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    lr_scheduler_type="constant_with_warmup",
    weight_decay=0.1,
    warmup_steps=2000,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
)
trainer.train()
Error:
File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 542, in <module>
main()
File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 512, in main
trainer.train()
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
return inner_training_loop(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step
loss = self.compute_loss(model, inputs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
outputs = model(**inputs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
return func(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 1048, in forward
mamba2_outputs = self.backbone(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 900, in forward
hidden_states = mixer_block(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 648, in forward
hidden_states = self.mixer(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 607, in forward
return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 531, in torch_forward
L = torch.exp(segment_sum(A))
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 105, in segment_sum
tensor_segsum = torch.cumsum(input_tensor, dim=-2)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 GiB. GPU 0 has a total capacity of 79.15 GiB of which 3.70 GiB is free. Including non-PyTorch memory, this process has 75.44 GiB memory in use. Of the allocated memory 74.80 GiB is allocated by PyTorch, and 139.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.
where,
GPU 0: NVIDIA A100-SXM4-80GB
GPU Utilization: 0%
Memory Utilization: 0%
Total Memory: 81920.00 MB
Free Memory: 52246.00 MB
Used Memory: 29674.00 MB
You're missing the mamba-ssm package and/or the causal-conv1d package:
return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) <-- this is the slow path, unoptimized and using a lot of memory
Hf only opts for the fast path, i.e. the kernels, if both are available (there should've been a warning), so consider installing the packages. For the future, these types of questions are more appropriate on the transformers repo ;)
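As a quick sanity check (a sketch, assuming the packages were pip-installed; the import names are mamba_ssm and causal_conv1d):

# check whether the kernel packages are importable in the current env
try:
    import mamba_ssm  # noqa: F401
    import causal_conv1d  # noqa: F401
    print("kernels found, hf should pick the fast path")
except ImportError as err:
    print("kernels missing, hf falls back to torch_forward:", err)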
Thanks for the response. I agree that it is more appropriate to post on the transformers repo; I'll make sure to do that next time.
I've tried it with the fast path too. I don't see this error when I work with Mamba (instead of Mamba2) with the same hyperparameters.
out, ssm_state = mamba_split_conv1d_scan_combined(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 930, in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 455, in decorate_fwd
return fwd(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 795, in forward
out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 317, in _mamba_chunk_scan_combined_fwd
states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_state_passing.py", line 205, in _state_passing_fwd
out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
This seems like a partial stack trace, no? Could you share the complete one?
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:164: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:240: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
def backward(ctx, dout):
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py:986: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
def forward(
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py:1045: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
def backward(ctx, dout, *args):
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/distributed/tensor_parallel.py:26: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/distributed/tensor_parallel.py:62: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
def backward(ctx, grad_output):
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py:758: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py:836: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
def backward(ctx, dout, *args):
Traceback (most recent call last):
File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 575, in <module>
main()
File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 486, in main
model.to(device)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2905, in to
return super().to(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1174, in to
return self._apply(convert)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 780, in _apply
module._apply(fn)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 780, in _apply
module._apply(fn)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 780, in _apply
module._apply(fn)
[Previous line repeated 2 more times]
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 805, in _apply
param_applied = fn(param)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1160, in convert
return t.to(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB. GPU 0 has a total capacity of 15.77 GiB of which 125.12 MiB is free. Including non-PyTorch memory, this process has 15.65 GiB memory in use. Of the allocated memory 15.35 GiB is allocated by PyTorch, and 1.23 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
It seems to me that you're passing the wrong arguments in your config initialization:
config = Mamba2Config(
    vocab_size=len(tokenizer.vocab),
    n_positions=10,
    n_embd=768,
    n_layer=12,
    n_head=12,
    n_inner=3072,
)
For example, there is no n_layer argument, so you will get the defaults of Codestral Mamba (a 7B model), and then it doesn't surprise me that you don't have enough memory. See https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba2/configuration_mamba2.py for the config file. You could also get the defaults from the conversion script at https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba2/convert_mamba2_ssm_checkpoint_to_pytorch.py (and inspect the values).
Also, this time it's a different GPU? Looks like 16GB of total VRAM instead of 80.
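For reference, a minimal config with the hf argument names could look roughly like this (a sketch with illustrative values, not the official 130m defaults):

from transformers import Mamba2Config, Mamba2ForCausalLM

config = Mamba2Config(
    vocab_size=50277,               # or len(tokenizer) for your own tokenizer
    hidden_size=768,
    num_hidden_layers=12,
    state_size=128,
    expand=2,
    head_dim=64,
    num_heads=(768 * 2) // 64,      # hidden_size * expand must be divisible by head_dim
    n_groups=1,
)
model = Mamba2ForCausalLM(config)
print(f"{model.num_parameters():,} parameters")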
I'll refer to the config, thanks.
config = AutoConfig.from_pretrained('state-spaces/mamba-130m')
model = MambaForCausalLM(config)
works fine. Mamba2 doesn't though: when run with
config = AutoConfig.from_pretrained('state-spaces/mamba2-130m')
model = Mamba2ForCausalLM(config)
I'm facing:
File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 575, in <module>
main()
File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 545, in main
trainer.train()
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
return inner_training_loop(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3349, in training_step
self.accelerator.backward(loss, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/accelerator.py", line 2155, in backward
self.scaler.scale(loss).backward(**kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
torch.autograd.backward(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
_engine_run_backward(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
return user_fn(self, *args)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 501, in decorate_bwd
return bwd(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 893, in backward
dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 381, in _mamba_chunk_scan_combined_bwd
states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_state_passing.py", line 205, in _state_passing_fwd
out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.15 GiB of which 810.00 MiB is free. Including non-PyTorch memory, this process has 78.35 GiB memory in use. Of the allocated memory 76.22 GiB is allocated by PyTorch, and 1.62 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
on an 80GB VRAM GPU (A100).
Those configs are not compatible with hf transformers. For mamba(1) there are compatible variants with "-hf" at the end, e.g. "state-spaces/mamba-130m-hf".
I'd assume that the default config values get loaded instead, which correspond to 2.7B and 7B models for mamba(1) and mamba(2) respectively, so you will more easily run into memory issues with mamba(2). You should be able to check a hf model's parameter count with .num_parameters().
The previous link to the config was meant to show which entries are possible, but since hf and the original repo use different conventions it might be hard to see what to change. The conversion script might help there, since you can use it to convert a model from the original repo to hf format - this way you pretty much get the correct default values for the config.
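For example, a quick way to inspect what the defaults actually are (sketch):

from transformers import Mamba2Config

config = Mamba2Config()  # no arguments -> the large default model
print(config.num_hidden_layers, config.hidden_size, config.num_heads)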
I understand that the configurations are not directly compatible with HuggingFace, and that using the conversion script can help obtain the correct default values for the configuration.
I’m having difficulty locating the Mamba2 130M checkpoint files necessary to use the conversion script. Could you please guide me on how to obtain these checkpoint files? Are they available for download, or is there a repository where I can access them?
Ah sure. You'll need git lfs, so potentially call git lfs install first. Afterwards, you can just clone the original repository at mamba2-130m, i.e. git clone https://huggingface.co/state-spaces/mamba2-130m.
If you want to wait a bit, I can upload the converted weights in fp32?
Sure, thanks. Please give me an update when you do.
@npkanaka
Added a repo at https://huggingface.co/AntonV/mamba2-130m-hf; you can just get the config via AutoConfig.from_pretrained('AntonV/mamba2-130m-hf') now.
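A minimal sketch of using it (random init for training from scratch; use from_pretrained on the model instead if you want the converted weights):

from transformers import AutoConfig, Mamba2ForCausalLM

config = AutoConfig.from_pretrained("AntonV/mamba2-130m-hf")
model = Mamba2ForCausalLM(config)  # randomly initialized, 130m-sized
# model = Mamba2ForCausalLM.from_pretrained("AntonV/mamba2-130m-hf")  # converted weights instead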
Thanks for providing the Mamba2 130M configuration and weights.
- I’m training the 130M model from scratch on an A100, with the same setup and parameters as GPT2. I’ve noticed that Mamba2 130M is taking about three times longer to train than GPT2 (150 hours vs. 50 hours). Is this longer training time expected due to architectural differences, or is there something I might be missing to optimize training speed for Mamba2?
- Using HF Trainer,
RuntimeError: The weights trying to be saved contained shared tensors [{'backbone.embeddings.weight', 'lm_head.weight'}] that are mismatching the transformers base configuration. Try saving using safe_serialization=False or remove this tensor sharing.
Is avoiding weight sharing acceptable? I.e. config.tie_word_embeddings = False
No problem.
- This seems a bit unlikely tbh. Have you ensured that mamba-ssm and causal-conv1d are installed? Maybe set config.use_cache=False during training at least. Otherwise, I'm out of quick options. Mamba2 is also more suitable for bigger architectures and longer sequences, so it might be faster in comparison when using 1B params + 4k sequence lengths, for example.
- Can you pass safe_serialization=False when saving your model? Tbh, weight sharing is the standard in Mamba2, so you should not disable it. But tied embeddings are a mixed thing; they might be beneficial or not.
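For example (a sketch; output_dir is a placeholder, and I'm assuming the save_safetensors flag in TrainingArguments maps to the same behaviour for Trainer-managed checkpoints):

from transformers import TrainingArguments

output_dir = "./mamba2-from-scratch"  # hypothetical path

# when saving manually:
# model.save_pretrained(output_dir, safe_serialization=False)

# for Trainer-managed checkpoints (assumption: save_safetensors controls the same switch):
training_args = TrainingArguments(output_dir=output_dir, save_safetensors=False)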
@vasqu do you perhaps know why the huggingface model breaks when you change the hidden_size in the config? Does it have to be a multiple of another parameter?
@Jellymoon
The only thing I'm aware of would be this line: hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim. expand is usually 2, so you need to make sure that (2 * hidden_size) % head_dim == 0.
Otherwise, this might also be relevant for the kernels: https://github.com/state-spaces/mamba/issues/352#issuecomment-2167093395.
And if you have a small reproducer or something, that would be ideal.
@vasqu I tried changing the expansion factor, and if I scale it relative to the hidden size it works, for example:
hidden_size=2048, expand=4
hidden_size=4096, expand=2
hidden_size=8192, expand=1
What does not work is changing the hidden_size without changing the expansion factor. Even if you halve or double the default hidden_size, (2 * hidden_size) % head_dim == 0 should still be satisfied, I think?
Here is what I have been trying on an RTX 3090 (I was trying to make a smaller model):
import torch
import time
from transformers import AutoTokenizer, Mamba2Model, Mamba2Config, MambaConfig, MambaModel
from tqdm import tqdm

# config = MambaConfig(
config = Mamba2Config(
    vocab_size=len(tokenizer.vocab),
    state_size=128,
    num_heads=128,
    head_dim=64,
    num_hidden_layers=4,
    hidden_size=2048,
    # expand=4
)

# model = MambaModel(config).to(torch.device("cuda"))
model = Mamba2Model(config).to(torch.device("cuda"))

t = time.process_time()
for i in tqdm(range(1000)):
    inputs = {'input_ids': torch.randint(0, len(tokenizer), (1, 4096), device=torch.device("cuda"))}
    outputs = model(**inputs, labels=inputs["input_ids"])
elapsed_time = time.process_time() - t
print(elapsed_time)
You also have to pass the correct number of heads, i.e.:
import torch
from transformers import AutoTokenizer, Mamba2Model, Mamba2Config

hidden_size = 2048
expand = 2
head_dim = 64
# this here is the key to pass too, not very nice to have it passed in the config imo but that's the current state now :(
num_heads = (hidden_size * expand) // head_dim

config = Mamba2Config(
    vocab_size=len(tokenizer.vocab),
    state_size=128,
    num_heads=num_heads,
    head_dim=head_dim,
    num_hidden_layers=4,
    hidden_size=hidden_size,
    expand=expand,
    # I'd suggest using this as it's default for all models in the mamba-ssm repo
    n_groups=1,
)
...
The scaling worked because the effective intermediate size (hidden_size * expand) didn't change @Jellymoon
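I.e., a quick check:

# the effective intermediate size (hidden_size * expand) is the same in all three cases
for hidden_size, expand in [(2048, 4), (4096, 2), (8192, 1)]:
    print(hidden_size, expand, hidden_size * expand)  # 8192 every time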
Thank you for the help!
I had a couple of questions regarding the model's default settings in the Huggingface configuration:
From the paper:
- Is the no_bias_terms parameter set to True by default in the Huggingface version of Mamba2?
- Does the configuration use RMSNorm instead of LayerNorm by default, or is there a specific way to set this manually?
I am trying to find the best hyperparameters that you've run experiments with, on top of AutoConfig.from_pretrained('AntonV/mamba2-130m-hf').
@npkanaka There are two bias parameters:
- Bias for the linear projections which is set to False by default ( https://github.com/huggingface/transformers/blob/d6ba1ac041ac0b07bc589dd82a67cfb76f75d0f9/src/transformers/models/mamba2/configuration_mamba2.py#L62 )
- Bias for the convolution which is set to True by default ( https://github.com/huggingface/transformers/blob/d6ba1ac041ac0b07bc589dd82a67cfb76f75d0f9/src/transformers/models/mamba2/configuration_mamba2.py#L64 )
There is a parameter that suggests you could use layernorm instead of RMS norm in hf but it's sadly false / misleading. RMS norm will always be applied and tbh, it shouldn't matter too much.
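E.g. (sketch):

from transformers import Mamba2Config

config = Mamba2Config()
print(config.use_bias, config.use_conv_bias)  # defaults: False (linear projections), True (conv1d)

# to also drop the conv bias, for example:
config = Mamba2Config(use_conv_bias=False)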
Why would setting use_cache=False speed up training though? I have encountered the same issue with max_length=100.
@Lynnzake At least in hf, you guarantee that inference is avoided 100%, and in that case the code opts for the fused path, i.e. a kernel for the conv combined with the mamba op. Anything "kernelized" is faster than running the base ops one after the other (even if they are one kernel each).
Not sure what you mean with max_len, but the mamba2 kernel at least is pretty unoptimized for smaller sequence lengths. The speed only gets pretty good at lengths >= 2048.
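E.g. (sketch):

from transformers import AutoConfig, Mamba2ForCausalLM

config = AutoConfig.from_pretrained("AntonV/mamba2-130m-hf")
config.use_cache = False  # per the above, training should then always hit the fused/kernel path
model = Mamba2ForCausalLM(config)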
Appreciate your reply; max_len is the input length for my model.
Then the second part of my reply should apply, i.e. max_len should be 2048 or longer to get good speed.
Yes.