
Mamba2 doesn't support Multi-GPU training

Open · NadavSc opened this issue 11 months ago • 2 comments

Hi! I'm using SFTTrainer (which inherits from the Transformers Trainer) to fine-tune Mamba2. When using cuda_kernels_forward in Mamba2 on multiple GPUs, the following error appears (full traceback at the end):

  config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
  TypeError: 'NoneType' object is not a mapping

However, it works just fine when I'm using the slower path, torch_forward. Do you know how to address this issue? Thanks a lot.
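
For reference, this is roughly how I force the slower path for now. It's only a sketch that monkey-patches the mixer: it assumes the Mamba2Mixer class and the torch_forward method names/signatures from transformers.models.mamba2.modeling_mamba2, which may differ between transformers versions.

  # Workaround sketch: make every Mamba2Mixer take the pure-PyTorch path
  # instead of cuda_kernels_forward. Method names/signatures are assumed
  # from transformers' modeling_mamba2 and may change between versions.
  from transformers.models.mamba2 import modeling_mamba2

  def _torch_only_forward(self, hidden_states, cache_params=None,
                          cache_position=None, attention_mask=None):
      # Always dispatch to torch_forward, skipping the Triton kernels.
      return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)

  modeling_mamba2.Mamba2Mixer.forward = _torch_only_forward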

Reproduction

  from datasets import load_dataset
  from trl import SFTTrainer, SFTConfig
  from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
  
  model_id = 'AntonV/mamba2-130m-hf'
  dataset_name = 'yelp_review_full'

  # Use EOS as the pad token and right padding; collate for causal LM (mlm=False)
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  tokenizer.pad_token = tokenizer.eos_token
  tokenizer.padding_side = 'right'
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
  model = AutoModelForCausalLM.from_pretrained(model_id)

  # Stream the training split so the full dataset is never downloaded up front
  dataset = load_dataset(dataset_name, split='train', streaming=True)
  train_dataset = dataset

  training_args = SFTConfig(
      output_dir='./outputs',
      num_train_epochs=5,
      per_device_train_batch_size=4,
      per_device_eval_batch_size=4,
      logging_dir='./logs',
      learning_rate=2e-3,
      save_steps=500,
      save_safetensors=False,
      max_steps=10000,
      report_to='none'
  )
  trainer = SFTTrainer(
      model=model,
      processing_class=tokenizer,
      data_collator=data_collator,
      args=training_args,
      train_dataset=train_dataset,
  )
  trainer.train()

Traceback

  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/train.py", line 82, in <module>
    main()
  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/train.py", line 79, in main
    trainer.train()
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 3612, in training_step
    self.accelerator.backward(loss, **kwargs)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/accelerate/accelerator.py", line 2248, in backward
    loss.backward(**kwargs)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 501, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 893, in backward
    dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 414, in _mamba_chunk_scan_combined_bwd
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 250, in _chunk_scan_chunk_state_bwd_dx
    _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 170, in run
    config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
TypeError: 'NoneType' object is not a mapping

NadavSc · Jan 18 '25

Maybe related to https://github.com/state-spaces/mamba/issues/84

vasqu · Jan 20 '25

Same error here.

RxqJkb · Apr 08 '25