
Training Script

Open loretoparisi opened this issue 1 year ago • 7 comments

It would be worth providing a training script, in order to train larger models (for instance, 7B or 13B).

loretoparisi avatar Dec 04 '23 18:12 loretoparisi

You can use whichever training script / library you'd like, e.g. Megatron, DeepSpeed, lightning, hf accelerate etc. Just have to replace the model definition.

Examples:
Lightning has lit-gpt: https://github.com/Lightning-AI/lit-gpt
FlashAttention has training code, and you can swap the model: https://github.com/Dao-AILab/flash-attention/tree/main/training

tridao avatar Dec 04 '23 22:12 tridao
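
As a rough illustration of "replace the model definition", here is a minimal plain-PyTorch sketch of a next-token training step. `TinyLM` is a hypothetical stand-in, not part of mamba_ssm; it only mimics returning an object with a `.logits` field, which is assumed to match what `MambaLMHeadModel` returns. The hyperparameters and the fake batch are arbitrary.

```python
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in for the real model: returns an object with a
# `.logits` field of shape (batch, seq_len, vocab_size).
Output = namedtuple("Output", ["logits"])


class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, d_model=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        return Output(logits=self.head(self.embed(input_ids)))


def train_step(model, optimizer, input_ids):
    """Run one next-token-prediction step and return the scalar loss."""
    logits = model(input_ids).logits
    # Position t predicts token t+1: drop the last logit and the first target.
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = TinyLM()  # the model definition is the only part meant to be swapped
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
batch = torch.randint(0, 100, (4, 16))  # fake token ids
losses = [train_step(model, optimizer, batch) for _ in range(20)]
```

To try this with Mamba, the assumption is that `TinyLM()` can be replaced by a `MambaLMHeadModel` instance on a CUDA device, with the rest of the loop unchanged.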

This does not seem to be so straightforward with the HF Trainer. Quite literally:

  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2748, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 680, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 668, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
TypeError: MambaLMHeadModel.forward() got an unexpected keyword argument 'labels'

So forward() takes no labels argument?

It would be very nice if you could provide a simple, minimal example of how to use the models with the HF Trainer. Thank you!

geronimi73 avatar Dec 05 '23 16:12 geronimi73

We've managed to train mamba by modifying the Huggingface Trainer class. Here is our implementation; we were actually able to train a chat model that seems to perform quite well.

justusmattern27 avatar Dec 07 '23 00:12 justusmattern27
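
The implementation referred to above is not reproduced in this thread. As an independent sketch of what such a Trainer modification might look like, the mixin below overrides `compute_loss` so that only `input_ids` reaches the model and the shifted cross-entropy is computed outside `forward()`. `CausalLMLossMixin` and `fake_model` are hypothetical names for illustration, and the model is assumed to return an object with a `.logits` field.

```python
from collections import namedtuple

import torch
import torch.nn.functional as F


class CausalLMLossMixin:
    """Hypothetical mixin: compute the LM loss without passing `labels` to the model."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        input_ids = inputs["input_ids"]
        # Use the inputs themselves as targets if the batch has no labels.
        labels = inputs.get("labels", input_ids)
        outputs = model(input_ids)  # no `labels` keyword is forwarded
        logits = outputs.logits
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            labels[:, 1:].reshape(-1),
            ignore_index=-100,  # conventional value for masked-out label positions
        )
        return (loss, outputs) if return_outputs else loss


# Intended use (assumes transformers is installed; untested against any version):
#   from transformers import Trainer
#   class MambaTrainer(CausalLMLossMixin, Trainer): pass

# Tiny check with a fake model that, like the failing case, accepts input_ids only.
Output = namedtuple("Output", ["logits"])


def fake_model(input_ids):
    return Output(logits=torch.zeros(*input_ids.shape, 50))


ids = torch.randint(0, 50, (2, 8))
loss = CausalLMLossMixin().compute_loss(fake_model, {"input_ids": ids, "labels": ids.clone()})
```

With uniform (all-zero) logits over 50 tokens, the loss is simply the log of the vocabulary size, which makes the shift-and-flatten logic easy to sanity-check.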

> We've managed to train mamba by modifying the Huggingface Trainer class. Here is our implementation, we were actually able to train a chat model that seems to perform quite well.

Cool, nice work! Are you using fp32 for this finetuning work?

binxuan avatar Dec 07 '23 02:12 binxuan

Hmm... Doesn't seem to work out of the box with lit-gpt.

Minimal example:

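# Assumes `fabric` is an already-configured lightning.Fabric instance,
# as in lit-gpt's pretraining scripts.
from lightning.fabric.strategies import DeepSpeedStrategy
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf

# empty_init is only enabled when the strategy is DeepSpeed.
with fabric.init_module(empty_init=isinstance(fabric.strategy, DeepSpeedStrategy)):
    config = load_config_hf("state-spaces/mamba-2.8b")
    model = MambaLMHeadModel(MambaConfig(**config))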

This gives the following error:

Traceback (most recent call last):
  File "/home/me/lit-gpt/pretrain/mamba.py", line 782, in <module>
    setup(run_config)
  File "/home/me/lit-gpt/pretrain/mamba.py", line 306, in setup
    main(fabric, run_config)
  File "/home/me/lit-gpt/pretrain/mamba.py", line 342, in main
    model = MambaLMHeadModel(MambaConfig(**model_config))
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 199, in __init__
    self.backbone = MixerModel(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 118, in __init__
    [
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 119, in <listcomp>
    create_block(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 42, in create_block
    block = Block(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 316, in __init__
    self.mixer = mixer_cls(dim)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 99, in __init__
    self.dt_proj.bias.copy_(inv_dt)
RuntimeError: The size of tensor a (0) must match the size of tensor b (5120) at non-singleton dimension 0

Calvinnncy97 avatar Dec 20 '23 09:12 Calvinnncy97

if that's true, how the heck am I to pass attention?

thistleknot avatar Dec 21 '23 20:12 thistleknot

DeepSpeed ZeRO stage 2 works, but stage 3 does not. I don't have a fix at the moment. The problem is this line: https://github.com/state-spaces/mamba/blob/eb2f7a520dd5e2949b7ae1c3ef44f6cb99faef5c/mamba_ssm/modules/mamba_simple.py#L98

Calvinnncy97 avatar Dec 22 '23 07:12 Calvinnncy97
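
One plausible reading of the traceback, not a confirmed diagnosis: under ZeRO stage 3 the DeepSpeed wrapper partitions each submodule's parameters as soon as its `__init__` returns, so by the time the Mamba block runs `self.dt_proj.bias.copy_(inv_dt)` the local bias is a size-0 placeholder. The sketch below reproduces only the tensor-level failure in plain PyTorch; the variable names are hypothetical and no DeepSpeed code is involved.

```python
import torch

full_value = torch.ones(5120)  # plays the role of inv_dt
placeholder = torch.empty(0)   # plays the role of a partitioned dt_proj.bias

try:
    # Copying a 5120-element tensor into a 0-element tensor cannot broadcast.
    placeholder.copy_(full_value)
    message = None
except RuntimeError as err:
    message = str(err)

print(message)
```

If this reading is right, the in-place initialization would need to run while the full parameter is materialized. DeepSpeed documents `deepspeed.zero.GatheredParameters` for modifying partitioned parameters, which may be relevant here, though that has not been tested against this model.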