Can't train mamba2 from scratch with HF Trainer

Open npkanaka opened this issue 1 year ago • 19 comments

I'm trying to train mamba2 130m from scratch.

from transformers import Mamba2Config, Mamba2ForCausalLM, Trainer, TrainingArguments

# tokenizer, args, tokenized_train and tokenized_eval are defined elsewhere in the original script
config = Mamba2Config(
    vocab_size=len(tokenizer.vocab),
    n_positions=10,
    n_embd=768,
    n_layer=12,
    n_head=12,
    n_inner=3072,
)
model = Mamba2ForCausalLM(config)

training_args = TrainingArguments(
    output_dir=args.output_dir,
    logging_dir='./logs',
    gradient_accumulation_steps=1,
    save_steps=2000,
    max_steps=1500000,
    evaluation_strategy="steps",
    eval_steps=2000,
    #prediction_loss_only=True,
    logging_strategy="epoch",
    #logging_steps=500,
    learning_rate=1e-4,
    # save_total_limit=999999,
    fp16=True,
    dataloader_num_workers=4,
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    lr_scheduler_type="constant_with_warmup",
    weight_decay=0.1,
    warmup_steps=2000,
)
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
)
trainer.train()

Error:

File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 542, in <module>
    main()
  File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 512, in main
    trainer.train()
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3318, in training_step
    loss = self.compute_loss(model, inputs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
    outputs = model(**inputs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 1048, in forward
    mamba2_outputs = self.backbone(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 900, in forward
    hidden_states = mixer_block(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 648, in forward
    hidden_states = self.mixer(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 607, in forward
    return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 531, in torch_forward
    L = torch.exp(segment_sum(A))
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.py", line 105, in segment_sum
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 GiB. GPU 0 has a total capacity of 79.15 GiB of which 3.70 GiB is free. Including non-PyTorch memory, this process has 75.44 GiB memory in use. Of the allocated memory 74.80 GiB is allocated by PyTorch, and 139.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.

where,

GPU 0: NVIDIA A100-SXM4-80GB
  GPU Utilization: 0%
  Memory Utilization: 0%
  Total Memory: 81920.00 MB
  Free Memory: 52246.00 MB
  Used Memory: 29674.00 MB

npkanaka, Sep 06 '24

You're missing the mamba-ssm package, the causal-conv1d package, or both: return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) in your traceback is the slow path, which is unoptimized and uses a lot of memory.

HF only opts for the fast path, i.e. the kernels, if both packages are available (there should have been a warning). So consider installing them. In the future, these types of questions are more appropriate on the transformers repo ;)
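A quick way to check whether the fast path can be used in a given environment is sketched below. It only tests whether the mamba_ssm and causal_conv1d modules are importable (the pip package names are mamba-ssm and causal-conv1d); HF's own check looks for specific kernel functions inside those packages, so treat this as an approximation, not the library's logic. The helper names are hypothetical.

```python
import importlib.util
from typing import Dict


def kernel_packages_status() -> Dict[str, bool]:
    """Report whether the optional Mamba kernel packages can be found."""
    # Module names differ from pip names: mamba-ssm -> mamba_ssm, causal-conv1d -> causal_conv1d
    return {
        name: importlib.util.find_spec(name) is not None
        for name in ("mamba_ssm", "causal_conv1d")
    }


def fast_path_likely_available() -> bool:
    """HF only takes the kernel (fast) path when both packages are present."""
    return all(kernel_packages_status().values())


if __name__ == "__main__":
    for name, found in kernel_packages_status().items():
        print(f"{name}: {'found' if found else 'missing'}")
    if not fast_path_likely_available():
        print("Expect the slow torch_forward path; try: pip install mamba-ssm causal-conv1d")
```

Run it with the same Python environment that launches the Trainer, since a package installed in a different environment will not be picked up.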

vasqu, Sep 06 '24

Thanks for the response. I agree that it is more appropriate to post on the transformers repo. I'll make sure to do that next time.

I've tried it with the fast path too. I don't see this error when I use Mamba (instead of Mamba2) with the same hyperparameters.

    out, ssm_state = mamba_split_conv1d_scan_combined(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 930, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 455, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 795, in forward
    out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 317, in _mamba_chunk_scan_combined_fwd
    states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_state_passing.py", line 205, in _state_passing_fwd
    out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)

npkanaka, Sep 07 '24

This seems like a partial stack trace, no? Could you share the complete one?

vasqu, Sep 08 '24

/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:164: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py:240: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout):
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py:986: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/layer_norm.py:1045: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/distributed/tensor_parallel.py:26: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/distributed/tensor_parallel.py:62: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, grad_output):
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py:758: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py:836: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, dout, *args):
Traceback (most recent call last):
  File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 575, in <module>
    main()
  File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 486, in main
    model.to(device)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2905, in to
    return super().to(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1174, in to
    return self._apply(convert)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 780, in _apply
    module._apply(fn)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 780, in _apply
    module._apply(fn)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 780, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 805, in _apply
    param_applied = fn(param)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1160, in convert
    return t.to(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB. GPU 0 has a total capacity of 15.77 GiB of which 125.12 MiB is free. Including non-PyTorch memory, this process has 15.65 GiB memory in use. Of the allocated memory 15.35 GiB is allocated by PyTorch, and 1.23 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

npkanaka, Sep 10 '24

It seems to me that with your config initialization, you pass the wrong arguments:

config = Mamba2Config(
    vocab_size=len(tokenizer.vocab),
    n_positions=10,
    n_embd=768,
    n_layer=12,
    n_head=12,
    n_inner=3072,
)

For example, there is no argument n_layer, so you will get the defaults of Codestral Mamba (a 7B model), and then it doesn't surprise me that you run out of memory. See https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba2/configuration_mamba2.py for the config file. You could also get the defaults from the conversion script in the source, https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba2/convert_mamba2_ssm_checkpoint_to_pytorch.py (and inspect the values).
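Because the config accepts extra keyword arguments without complaining, a GPT-2-style name like n_layer is simply stored as an extra attribute while the real field keeps its default. A small sketch, using only the standard inspect module, that flags keyword arguments which are not explicit parameters of a config class's __init__. The helper unknown_kwargs and the ToyConfig stand-in are hypothetical, chosen so the example runs without transformers installed.

```python
import inspect
from typing import Any, Dict, List


def unknown_kwargs(cls: type, kwargs: Dict[str, Any]) -> List[str]:
    """Return the keys in `kwargs` that are not explicit parameters of cls.__init__."""
    accepted = {
        name
        for name, param in inspect.signature(cls).parameters.items()
        # *args / **kwargs swallow anything, so they don't count as known names
        if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    }
    return sorted(key for key in kwargs if key not in accepted)


class ToyConfig:
    """Hypothetical stand-in that, like many config classes, stores unknown kwargs as attributes."""

    def __init__(self, hidden_size: int = 4096, num_hidden_layers: int = 64, **kwargs: Any):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        for key, value in kwargs.items():
            setattr(self, key, value)


gpt2_style = {"hidden_size": 768, "n_layer": 12, "n_head": 12}
print(unknown_kwargs(ToyConfig, gpt2_style))
cfg = ToyConfig(**gpt2_style)
print(cfg.num_hidden_layers)  # still the default; n_layer was stored as a separate attribute
```

With transformers installed, running unknown_kwargs(Mamba2Config, ...) on the arguments from the original snippet should flag n_embd, n_head, n_inner, n_layer and n_positions, while vocab_size is a genuine Mamba2Config parameter.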

Also, is this a different GPU this time? It looks like 16 GB of total memory instead of 80.

vasqu, Sep 10 '24

I'll refer to the config, thanks.

from transformers import AutoConfig, MambaForCausalLM

config = AutoConfig.from_pretrained('state-spaces/mamba-130m')
model = MambaForCausalLM(config)

Works fine.

Mamba2 doesn't, though, when run with:

from transformers import AutoConfig, Mamba2ForCausalLM

config = AutoConfig.from_pretrained('state-spaces/mamba2-130m')
model = Mamba2ForCausalLM(config)

I get:

File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 575, in <module>
    main()
  File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 545, in main
    trainer.train()
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3349, in training_step
    self.accelerator.backward(loss, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/accelerator.py", line 2155, in backward
    self.scaler.scale(loss).backward(**kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 501, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 893, in backward
    dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 381, in _mamba_chunk_scan_combined_bwd
    states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_state_passing.py", line 205, in _state_passing_fwd
    out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.15 GiB of which 810.00 MiB is free. Including non-PyTorch memory, this process has 78.35 GiB memory in use. Of the allocated memory 76.22 GiB is allocated by PyTorch, and 1.62 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

on an A100 with 80 GB of VRAM.

npkanaka, Sep 11 '24

Those configs are not compatible with HF transformers. For Mamba(1), compatible variants exist with "-hf" at the end, e.g. "state-spaces/mamba-130m-hf".

I'd assume that the default config values are loaded instead, which correspond to 2.7B and 7B models for Mamba(1) and Mamba(2), respectively. You will then run into memory issues more easily with Mamba(2). You can check the parameter count of an HF model with .num_parameters().
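To see why the defaults alone can exhaust an 80 GB card, here is a rough estimate of training memory from the parameter count. It assumes the common rule of thumb of 16 bytes per parameter for mixed-precision AdamW training (fp32 weights, fp32 gradients, two fp32 optimizer moments) and ignores activations, so real usage is higher. The parameter counts are the approximate sizes mentioned in this thread, not measured values, and the helper name is hypothetical.

```python
def training_memory_gib(num_params: int, bytes_per_param: int = 16) -> float:
    """Rough lower bound on GPU memory for weights, gradients and AdamW states.

    16 bytes/param assumes fp32 weights (4) + fp32 gradients (4) + two fp32
    AdamW moments (8), as kept by mixed-precision training with torch AMP
    (fp16=True). Activations, which grow with batch size and sequence length,
    are NOT included.
    """
    return num_params * bytes_per_param / 1024**3


for label, n in [("~130M", 130_000_000), ("~2.7B", 2_700_000_000), ("~7B", 7_000_000_000)]:
    print(f"{label} params: at least {training_memory_gib(n):.1f} GiB before activations")
```

With an actual model, replace the hard-coded counts with model.num_parameters().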

The previous link to the config was meant to show which entries are possible, but since HF and this repo use different conventions, it might be hard to see what to change. The conversion script might help there, since you can use it to convert a model from this repo to HF format; this way you get essentially the correct default values for the config.
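As a rough illustration of what the conversion amounts to for the config, the sketch below maps a mamba-ssm style config.json (keys d_model, n_layer, vocab_size, ssm_cfg, pad_vocab_size_multiple, tie_embeddings) to Mamba2Config keyword arguments. It falls back to the mamba-ssm Mamba2 layer defaults (d_state=128, d_conv=4, expand=2, headdim=64, ngroups=1) when ssm_cfg omits them, and pads the vocabulary the way mamba-ssm does. It is a hypothetical re-implementation of the idea, not the conversion script itself, and the example input is illustrative only; use the script and the real checkpoint's config.json for actual values.

```python
from typing import Any, Dict


def ssm_to_hf_kwargs(ssm_config: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch: map a mamba-ssm style config dict to Mamba2Config keyword arguments."""
    ssm_cfg = ssm_config.get("ssm_cfg", {})
    hidden_size = ssm_config["d_model"]
    expand = ssm_cfg.get("expand", 2)       # mamba-ssm Mamba2 default
    head_dim = ssm_cfg.get("headdim", 64)   # mamba-ssm Mamba2 default

    # mamba-ssm pads the vocabulary up to a multiple of pad_vocab_size_multiple
    vocab_size = ssm_config["vocab_size"]
    multiple = ssm_config.get("pad_vocab_size_multiple", 1)
    if vocab_size % multiple:
        vocab_size += multiple - vocab_size % multiple

    return {
        "vocab_size": vocab_size,
        "hidden_size": hidden_size,
        "num_hidden_layers": ssm_config["n_layer"],
        "state_size": ssm_cfg.get("d_state", 128),
        "conv_kernel": ssm_cfg.get("d_conv", 4),
        "expand": expand,
        "head_dim": head_dim,
        "n_groups": ssm_cfg.get("ngroups", 1),
        # same relation as the num_heads line in the HF conversion script
        "num_heads": (hidden_size * expand) // head_dim,
        "tie_word_embeddings": ssm_config.get("tie_embeddings", True),
    }


# Illustrative input only; read the real values from the checkpoint's config.json.
example = {
    "d_model": 768,
    "n_layer": 24,
    "vocab_size": 50277,
    "ssm_cfg": {"layer": "Mamba2"},
    "pad_vocab_size_multiple": 16,
    "tie_embeddings": True,
}
print(ssm_to_hf_kwargs(example))
```

The resulting dict could then be passed as Mamba2Config(**kwargs), once checked against the conversion script's output.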

vasqu, Sep 11 '24

I understand that the configurations are not directly compatible with HuggingFace, and that using the conversion script can help obtain the correct default values for the configuration.

I’m having difficulty locating the Mamba2 130M checkpoint files necessary to use the conversion script. Could you please guide me on how to obtain these checkpoint files? Are they available for download, or is there a repository where I can access them?

npkanaka, Sep 16 '24

Ah, sure. You'll need git lfs, so you may have to run git lfs install first. Afterwards, you can simply clone the original mamba2-130m repository, i.e. git clone https://huggingface.co/state-spaces/mamba2-130m.

If you're willing to wait a bit, I can upload the converted weights in fp32. Would that help?

vasqu, Sep 16 '24

Sure, thanks. Please let me know once you've done so.

npkanaka, Sep 16 '24

@npkanaka

I added a repo at https://huggingface.co/AntonV/mamba2-130m-hf; you can now get the config via AutoConfig.from_pretrained('AntonV/mamba2-130m-hf').

vasqu, Sep 16 '24

Thanks for providing the Mamba2 130M configuration and weights.

  1. I’m training the 130M model from scratch on an A100, with the same setup and parameters as GPT2. I’ve noticed that Mamba2 130M is taking about three times longer to train than GPT2 (150 hours vs. 50 hours). Is this longer training time expected due to architectural differences, or is there something I might be missing to optimize training speed for Mamba2?
  2. With HF Trainer, saving fails with RuntimeError: The weights trying to be saved contained shared tensors [{'backbone.embeddings.weight', 'lm_head.weight'}] that are mismatching the transformers base configuration. Try saving using safe_serialization=False or remove this tensor sharing. Is it acceptable to avoid weight sharing by setting config.tie_word_embeddings = False?

npkanaka, Sep 18 '24

No problem.

  1. This seems a bit unlikely, to be honest. Have you made sure that mamba-ssm and causal-conv1d are installed? At least try setting config.use_cache=False during training. Otherwise, I'm out of quick options. Mamba2 is also better suited to bigger architectures and longer sequences, so it may compare more favorably at, for example, 1B parameters and 4k sequence lengths.
  2. Can you pass safe_serialization=False when saving your model? Weight sharing is the standard in Mamba2, so you shouldn't disable it. That said, tied embeddings are a mixed bag; they may or may not be beneficial.
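For a model this small, the tying choice also changes the parameter budget noticeably: untying gives lm_head its own vocab_size x hidden_size matrix instead of reusing the embedding. A back-of-the-envelope sketch, assuming hidden_size=768 (as in the 130M setups discussed here), a vocabulary of 50,288 entries (an assumed value; read config.vocab_size for yours) and an lm_head without bias. The helper name is hypothetical.

```python
def untied_lm_head_params(vocab_size: int, hidden_size: int) -> int:
    """Extra parameters when lm_head stops sharing the embedding matrix (no bias assumed)."""
    return vocab_size * hidden_size


# Assumed sizes for illustration: hidden_size=768 and a ~50k vocabulary.
extra = untied_lm_head_params(vocab_size=50_288, hidden_size=768)
print(f"{extra:,} extra parameters, about {extra / 130e6:.0%} of a 130M model")
```

So setting tie_word_embeddings=False makes the "130M" model noticeably larger, which is worth keeping in mind when comparing against a tied GPT-2 baseline.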

vasqu, Sep 19 '24

@vasqu do you perhaps know why the huggingface model breaks when you change the hidden_size in the config? Does it have to be a multiple of another parameter?

Jellymoon, Sep 20 '24

@Jellymoon The only thing I'm aware of is this line: hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim. Since expand is usually 2, you need to make sure that (2 * hidden_size) % head_dim == 0.
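The constraint above can be turned into a small helper that computes num_heads from the quoted line and fails early instead of producing a mismatched config. This is a sketch of the check described here (the helper name is hypothetical), not code from transformers; the result is meant to be passed as num_heads when building Mamba2Config.

```python
def mamba2_num_heads(hidden_size: int, expand: int = 2, head_dim: int = 64) -> int:
    """Compute num_heads as in the quoted conversion line, rejecting sizes that don't divide evenly."""
    inner = hidden_size * expand  # intermediate (inner) width of the Mamba2 mixer
    if inner % head_dim != 0:
        raise ValueError(
            f"expand * hidden_size = {inner} is not divisible by head_dim = {head_dim}"
        )
    return inner // head_dim


for hidden_size in (768, 2048, 4096):
    print(hidden_size, mamba2_num_heads(hidden_size))
```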

Otherwise, this might also be relevant for the kernels: https://github.com/state-spaces/mamba/issues/352#issuecomment-2167093395.

And if you have a small reproducer or something, that would be ideal.

vasqu, Sep 20 '24

@vasqu I tried changing the expansion factor, and if I scale it inversely to the hidden size, it works, for example:

hidden_size=2048, expand=4
hidden_size=4096, expand=2
hidden_size=8192, expand=1

What does not work is changing hidden_size without changing the expansion factor. Even if you halve or double the default hidden_size, (2 * hidden_size) % head_dim == 0 should still be satisfied, I think?

Here is what I have been trying on an RTX 3090; I was trying to make a smaller model:

import time

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, Mamba2Config, Mamba2Model, MambaConfig, MambaModel

# `tokenizer` is assumed to be loaded earlier (not shown in the original snippet)
device = torch.device("cuda")

# config = MambaConfig(
config = Mamba2Config(
    vocab_size=len(tokenizer.vocab),
    state_size=128,
    num_heads=128,
    head_dim=64,
    num_hidden_layers=4,
    hidden_size=2048,
    # expand=4
)

# model = MambaModel(config).to(device)
model = Mamba2Model(config).to(device)

# CUDA kernels run asynchronously, so synchronize and measure wall-clock time
# (time.process_time() only counts this process's CPU time).
torch.cuda.synchronize()
t = time.perf_counter()
for _ in tqdm(range(1000)):
    inputs = {"input_ids": torch.randint(0, len(tokenizer), (1, 4096), device=device)}
    # Mamba2Model returns hidden states only; labels are only used by Mamba2ForCausalLM
    outputs = model(**inputs)
torch.cuda.synchronize()

elapsed_time = time.perf_counter() - t
print(elapsed_time)

Jellymoon, Sep 20 '24

You also have to pass the correct number of heads, i.e.:

import torch
from transformers import AutoTokenizer, Mamba2Model, Mamba2Config

hidden_size = 2048
expand = 2
head_dim = 64
# this here is the key to pass too, not very nice to have it passed in the config imo but that's the current state now :(
num_heads = (hidden_size * expand) // head_dim

config = Mamba2Config(
    vocab_size=len(tokenizer.vocab),
    state_size=128,
    num_heads=num_heads,
    head_dim=head_dim,
    num_hidden_layers=4,
    hidden_size=hidden_size,
    expand=expand,
    # I'd suggest using this as it's default for all models in the mamba-ssm repo
    n_groups=1,
)

...

vasqu, Sep 20 '24

The scaling worked because the effective size didn't change @Jellymoon
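The arithmetic behind this: in all three working combinations hidden_size * expand equals 8192, so the num_heads=128 passed in the reproducer (with head_dim=64) stays consistent, while hidden_size=2048 with the default expand=2 only implies 64 heads. A small sketch of that check (the helper name is hypothetical):

```python
HEAD_DIM = 64
NUM_HEADS_PASSED = 128  # num_heads passed explicitly in the reproducer above


def heads_needed(hidden_size: int, expand: int, head_dim: int = HEAD_DIM) -> int:
    """Number of heads implied by hidden_size * expand (the mixer's inner width)."""
    return (hidden_size * expand) // head_dim


for hidden_size, expand in [(2048, 4), (4096, 2), (8192, 1), (2048, 2)]:
    needed = heads_needed(hidden_size, expand)
    verdict = "matches" if needed == NUM_HEADS_PASSED else "does not match"
    print(f"hidden_size={hidden_size}, expand={expand}: needs {needed} heads, {verdict} num_heads={NUM_HEADS_PASSED}")
```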

vasqu, Sep 20 '24

Thank you for the help!

Jellymoon, Sep 20 '24

I had a couple of questions regarding the model's default settings in the Hugging Face configuration:

From the paper: [image not included]

Is the no_bias_terms parameter set to True by default in the Hugging Face version of Mamba2? Does the configuration use RMSNorm instead of LayerNorm by default, or is there a specific way to set this manually?

I am trying to find the best hyperparameters on top of AutoConfig.from_pretrained('AntonV/mamba2-130m-hf'), i.e. the ones you've run experiments with.

npkanaka, Oct 07 '24

@npkanaka There are two bias parameters:

  • Bias for the linear projections (use_bias), which is set to False by default ( https://github.com/huggingface/transformers/blob/d6ba1ac041ac0b07bc589dd82a67cfb76f75d0f9/src/transformers/models/mamba2/configuration_mamba2.py#L62 )
  • Bias for the convolution (use_conv_bias), which is set to True by default ( https://github.com/huggingface/transformers/blob/d6ba1ac041ac0b07bc589dd82a67cfb76f75d0f9/src/transformers/models/mamba2/configuration_mamba2.py#L64 )

There is a parameter that suggests you could use LayerNorm instead of RMSNorm in HF, but it's sadly misleading: RMSNorm is always applied. To be honest, it shouldn't matter too much.
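For intuition on why the choice matters little: without learned parameters, LayerNorm is just RMSNorm applied to the mean-centered input, so the two differ only by mean-centering (and LayerNorm's optional bias). A minimal pure-Python sketch; the norms in the HF Mamba2 implementation also have a learned weight, and the one inside the mixer is additionally gated, both of which this sketch leaves out.

```python
import math
from typing import List


def rms_norm(x: List[float], eps: float = 1e-5) -> List[float]:
    """RMSNorm without a learned weight: divide by the root mean square."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def layer_norm(x: List[float], eps: float = 1e-5) -> List[float]:
    """LayerNorm without learned weight/bias: subtract the mean, then divide by the std."""
    mean = sum(x) / len(x)
    centered = [v - mean for v in x]
    # The (biased) variance of x equals the mean square of the centered values,
    # so LayerNorm here is exactly RMSNorm applied to the centered input.
    return rms_norm(centered, eps)


print(rms_norm([1.0, 2.0, 3.0, 4.0]))
print(layer_norm([1.0, 2.0, 3.0, 4.0]))
```

Applied to an input that already has zero mean, such as [1.0, -1.0, 2.0, -2.0], both functions return the same values.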

vasqu, Oct 08 '24

Why would setting use_cache=False speed up training, though? I have encountered the same issue with max_length=100.

Lynnzake, Dec 30 '24

@Lynnzake At least in HF, setting use_cache=False guarantees that the inference path is never taken, and in that case the code opts for the fused path, i.e. a single kernel that combines the conv with the Mamba op. Anything "kernelized" like this is faster than running the base ops one after the other, even if each of them is its own kernel.

Not sure what you mean by max_len, but the Mamba2 kernel, at least, is fairly unoptimized for short sequence lengths. The speed only gets really good at lengths of 2048 or more.

vasqu, Dec 30 '24

I appreciate your reply. max_len is the input length for my models.

Lynnzake, Dec 30 '24

Then the second part of my reply should apply (?), i.e. max_len should be 2048 or longer to get good speed.

vasqu, Dec 30 '24

Yes.

Lynnzake, Dec 30 '24