Mamba2 doesn't support Multi-GPU training
Hi! I'm using SFTTrainer (which inherits from the Transformers Trainer) to fine-tune Mamba2. When training with cuda_kernels_forward in Mamba2 on multiple GPUs, the following error appears (full traceback at the end):

```
config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
TypeError: 'NoneType' object is not a mapping
```

However, it works just fine when I use the slower path, torch_forward. Do you know how to address this issue? Thanks a lot.
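For now I work around it by forcing the slow path. Below is a rough sketch of what I do; I'm assuming the dispatch in transformers' modeling_mamba2 reads the module-level is_fast_path_available flag at call time, which may differ across versions:

```python
# Workaround sketch (assumption: Mamba2Mixer.forward consults the module-level
# is_fast_path_available flag when choosing between cuda_kernels_forward and
# torch_forward; verify against your installed transformers version).
from transformers.models.mamba2 import modeling_mamba2

modeling_mamba2.is_fast_path_available = False  # fall back to torch_forward
```

With this in place the multi-GPU run goes through torch_forward and trains, just noticeably slower.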
Reproduction
```python
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling

model_id = 'AntonV/mamba2-130m-hf'
dataset_name = 'yelp_review_full'

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = AutoModelForCausalLM.from_pretrained(model_id)

dataset = load_dataset(dataset_name, split='train', streaming=True)
train_dataset = dataset

training_args = SFTConfig(
    output_dir='./outputs',
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_dir='./logs',
    learning_rate=2e-3,
    save_steps=500,
    save_safetensors=False,
    max_steps=10000,
    report_to='none'
)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    data_collator=data_collator,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
```
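In case it helps, this is roughly how I check which path the model will pick before training. It's only a sanity check, under the assumption that the fast path is gated on the optional mamba_ssm and causal_conv1d packages, which is what the traceback below suggests:

```python
# Hypothetical sanity check (assumption: the fused cuda_kernels_forward path is
# only used when these optional packages import cleanly; otherwise transformers
# falls back to torch_forward).
try:
    import mamba_ssm       # noqa: F401
    import causal_conv1d   # noqa: F401
    print("fast path (cuda_kernels_forward) available")
except ImportError:
    print("falling back to torch_forward")
```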
Traceback
File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/train.py", line 82, in <module>
main()
File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/train.py", line 79, in main
trainer.train()
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 2123, in train
return inner_training_loop(
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 3612, in training_step
self.accelerator.backward(loss, **kwargs)
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/accelerate/accelerator.py", line 2248, in backward
loss.backward(**kwargs)
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/_tensor.py", line 521, in backward
torch.autograd.backward(
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/__init__.py", line 289, in backward
_engine_run_backward(
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/function.py", line 306, in apply
return user_fn(self, *args)
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 501, in decorate_bwd
return bwd(*args, **kwargs)
File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 893, in backward
dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 414, in _mamba_chunk_scan_combined_bwd
dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 250, in _chunk_scan_chunk_state_bwd_dx
_chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 170, in run
config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
TypeError: 'NoneType' object is not a mapping
Possibly related to https://github.com/state-spaces/mamba/issues/84.
Same error here.