Implement MambaForSequenceClassification
What does this PR do?
Adds the MambaForSequenceClassification model based on the MambaModel backbone.
We recently published EHRMamba, a state-of-the-art foundation model for Electronic Health Records. This model is built on the same architecture and we will release the trained weights using the MambaForSequenceClassification class. https://vectorinstitute.github.io/EHRMamba
Fixes #30431
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [X] Did you read the contributor guideline, Pull Request section?
- [X] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case. https://github.com/huggingface/transformers/issues/30431
- [X] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [X] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
As discussed in https://github.com/huggingface/transformers/issues/30431, @ArthurZucker could you take a look? 😊
Notes
This implementation closely follows GPT2ForSequenceClassification, except that the last hidden states are pooled before being passed to the classifier, which improves efficiency.
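For illustration, here is a minimal sketch of that pooling step, assuming the GPT-2 pattern of selecting the last non-padding token per sequence; the function and argument names are illustrative, not the exact code in this PR.

import torch

def pooled_classification_logits(hidden_states, input_ids, pad_token_id, classifier):
    # hidden_states: (batch, seq_len, hidden_size) from the MambaModel backbone
    if pad_token_id is None:
        sequence_lengths = -1  # no padding information: fall back to the final position
    else:
        # index of the last non-padding token in each sequence
        sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
        sequence_lengths = sequence_lengths % input_ids.shape[-1]
    batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
    pooled = hidden_states[batch_idx, sequence_lengths]  # (batch, hidden_size)
    return classifier(pooled)  # (batch, num_labels): one classification per sequence instead of per position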
Referring to https://github.com/huggingface/transformers/pull/29552, "there's a test specific to sequence classification that expects all the unfrozen params to be initialized in the range [0.0, 1.0] and the initialized values for the mixer don't satisfy that assertion."
This results in a test failure even though the classifier head is initialized properly.
Could you rebase on main and make sure the CIs are green! 🤗
Of course! It should be good to merge now. There is a failing test, "MobileViTV2ModelTest" or similar, which is unrelated to Mamba.
@ArthurZucker I did some digging on prior decoder-model-for-classification implementations and realized some of them (e.g. gpt2) use caching. It seems the use case is when you want to do classification at different milestones within the sequence: by caching, you don't need to restart decoding from the beginning, which makes the process more efficient. It's a rare use case but worth having. A sketch of what that looks like is below.
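As an illustration of that use case, a minimal sketch with GPT2ForSequenceClassification (the cited precedent); the checkpoint and the split into two "milestones" are just for demonstration.

import torch
from transformers import AutoTokenizer, GPT2ForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=2)
model.config.pad_token_id = tokenizer.eos_token_id  # GPT-2 has no pad token by default

prefix = tokenizer("The first part of a long document", return_tensors="pt")
suffix = tokenizer(" and the rest of it", return_tensors="pt")

with torch.no_grad():
    # Milestone 1: classify the prefix and keep the cache it produced.
    out1 = model(**prefix, use_cache=True)
    # Milestone 2: feed only the new tokens plus the cached state, instead of
    # re-running the model over prefix + suffix from scratch.
    attention_mask = torch.cat([prefix["attention_mask"], suffix["attention_mask"]], dim=-1)
    out2 = model(
        input_ids=suffix["input_ids"],
        attention_mask=attention_mask,
        past_key_values=out1.past_key_values,
    )

print(out1.logits, out2.logits)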
Okay!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@ArthurZucker I merged the latest HuggingFace main into my branch. It should be ready to merge.
@ArthurZucker I would love a review!
hey any updates on this PR?
For finetuning
The docs show only the PEFT fine-tuning approach. Does this also support full fine-tuning using Trainer?
Thanks!
Yes, it should support it.
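For reference, a minimal full fine-tuning sketch with Trainer; the dataset, checkpoint, and hyperparameters below are placeholders, and it assumes this PR's branch is installed so MambaForSequenceClassification is importable.

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    MambaForSequenceClassification,  # available once this PR's branch is installed
    Trainer,
    TrainingArguments,
)

checkpoint = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # needed for dynamic padding

model = MambaForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=2,
    use_cache=False,  # avoids the eval-loop issue discussed further down in this thread
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("imdb")
tokenized = dataset.map(lambda batch: tokenizer(batch["text"], truncation=True), batched=True)

trainer = Trainer(
    model=model,  # every parameter is trainable, i.e. full fine-tuning rather than PEFT
    args=TrainingArguments(output_dir="mamba-cls", eval_strategy="epoch", num_train_epochs=1),
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)
trainer.train()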
This should merge soon after @ArthurZucker's final review.
@amyeroberts I was wondering if you might have any updates from @ArthurZucker? I haven't heard from him in about a month and just wanted to check if everything is alright.
Hey sorry, slipped through my notifications
@ArthurZucker Makes sense, I changed the classification head to a linear layer.
@ArthurZucker Pending review!
@ArthurZucker Now that https://github.com/huggingface/transformers/pull/32080 is merged, can we do a final review for this one too? Also, I would like to add Mamba2ForSequenceClassification to this PR as well so we have both Mamba models with classification capabilities. Then I would be able to release the EHRMamba model on HuggingFace.
Hey, any update on this?
@Adibvafa have you tried running this? Whenever the model gets to an evaluation step I get the error below. The code I tried was the Hugging Face sequence classification tutorial (link to tutorial), but I used a GPU, replaced "distilbert/distilbert-base-uncased" with "state-spaces/mamba-130m-hf", and replaced my local modeling_mamba.py with yours so it does load the model.
ERROR:
Traceback (most recent call last):
File ".../tutorial_script.py", line 71, in <module>
File ".../lib/python3.11/site-packages/transformers/trainer.py", line 3754, in predict
output = eval_loop(
^^^^^^^^^^
File ".../lib/python3.11/site-packages/transformers/trainer.py", line 3887, in evaluation_loop
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/accelerate/accelerator.py", line 2508, in pad_across_processes
return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 411, in wrapper
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 678, in pad_across_processes
return recursively_apply(
^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 107, in recursively_apply
return honor_type(
^^^^^^^^^^^
File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 81, in honor_type
return type(obj)(generator)
^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 110, in <genexpr>
recursively_apply(
File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 128, in recursively_apply
raise TypeError(
TypeError: Unsupported types (<class 'transformers.cache_utils.MambaCache'>) passed to `_pad_across_processes`.
Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.
I installed your forked version of the huggingface repo with %pip install -qq git+https://github.com/Adibvafa/MambaForSequenceClassification.git. I'm trying to fine-tune on a text-classification dataset and I get an error similar to @Jellymoon's. I'm trying to understand your code. Can you help with what might be the issue here, @Adibvafa?
mlflow.set_experiment(MLFLOW_EXPERIMENT)
MODEL_OUTPUT_DIR = f"{TRAINER_DIR}/{DATASET_BASE_NAME}/{ts}"
training_args = TrainingArguments(
    output_dir=MODEL_OUTPUT_DIR,
    overwrite_output_dir=True,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=WEIGHT_DECAY,
    label_smoothing_factor=LABEL_SMOOTHING,  # without this the model becomes overconfident; keep a non-zero value
    # bf16=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    group_by_length=True,  # smart batching; useful for dynamic padding
    load_best_model_at_end=True,
    disable_tqdm=False,
    save_total_limit=3,
    metric_for_best_model="f1",  # must be one of the metrics defined in the `compute_metrics` func
    greater_is_better=True,
    # auto_find_batch_size=True,  # will only reduce the batch size, might not be optimal
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"].select(range(50)),
    eval_dataset=tokenized_data["valid"],
    tokenizer=tokenizer,
    data_collator=data_collator,  # aids in dynamic padding
    compute_metrics=compute_metrics,
)
# trainer.add_callback(EvalTrainDataCallback(trainer))  # removed because it interfered with early stopping
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001))
Std err/output:
[2024-08-25 06:37:20,144] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
[WARNING] using untested triton version (2.3.1), only 1.0.0 is known to be compatible
[rank0]:[W reducer.cpp:1389] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
TypeError: Unsupported types (<class 'transformers.cache_utils.MambaCache'>) passed to `_pad_across_processes`. Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.
File <command-3628464069596680>, line 10
7 mlflow.set_tags(tags)
9 # Start Training
---> 10 trainer.train()
12 # Save trained model to local dir
13 PIPELINE_DIR = f'{MODEL_OUTPUT_DIR}/pipeline'
File /databricks/python/lib/python3.11/site-packages/mlflow/utils/autologging_utils/safety.py:456, in safe_patch.<locals>.safe_patch_function(*args, **kwargs)
441 if (
442 active_session_failed
443 or autologging_is_disabled(autologging_integration)
(...)
450 # warning behavior during original function execution, since autologging is being
451 # skipped
452 with set_non_mlflow_warnings_behavior_for_current_thread(
453 disable_warnings=False,
454 reroute_warnings=False,
455 ):
--> 456 return original(*args, **kwargs)
458 # Whether or not the original / underlying function has been called during the
459 # execution of patched code
460 original_has_been_called = False
File /databricks/python_shell/dbruntime/huggingface_patches/transformers.py:54, in _create_patch_function.<locals>.patched_fit_function(self, *args, **kwargs)
52 call_succeeded = False
53 try:
---> 54 model = original_method(self, *args, **kwargs)
55 call_succeeded = True
56 return model
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:1954, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1952 hf_hub_utils.enable_progress_bars()
1953 else:
-> 1954 return inner_training_loop(
1955 args=args,
1956 resume_from_checkpoint=resume_from_checkpoint,
1957 trial=trial,
1958 ignore_keys_for_eval=ignore_keys_for_eval,
1959 )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:2392, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2389 self.control.should_training_stop = True
2391 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 2392 self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
2394 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
2395 if is_torch_xla_available():
2396 # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:2820, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
2818 metrics = None
2819 if self.control.should_evaluate:
-> 2820 metrics = self._evaluate(trial, ignore_keys_for_eval)
2822 if self.control.should_save:
2823 self._save_checkpoint(model, trial, metrics=metrics)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:2777, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
2776 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2777 metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2778 self._report_to_hp_search(trial, self.state.global_step, metrics)
2780 # Run delayed LR scheduler now that metrics are populated
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:3701, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
3698 start_time = time.time()
3700 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3701 output = eval_loop(
3702 eval_dataloader,
3703 description="Evaluation",
3704 # No point gathering the predictions if there are no metrics, otherwise we defer to
3705 # self.args.prediction_loss_only
3706 prediction_loss_only=True if self.compute_metrics is None else None,
3707 ignore_keys=ignore_keys,
3708 metric_key_prefix=metric_key_prefix,
3709 )
3711 total_batch_size = self.args.eval_batch_size * self.args.world_size
3712 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:3912, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
3910 labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
3911 if logits is not None:
-> 3912 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
3913 if self.preprocess_logits_for_metrics is not None:
3914 logits = self.preprocess_logits_for_metrics(logits, labels)
File /databricks/python/lib/python3.11/site-packages/accelerate/accelerator.py:2482, in Accelerator.pad_across_processes(self, tensor, dim, pad_index, pad_first)
2449 def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
2450 """
2451 Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
2452 they can safely be gathered.
(...)
2480 ```
2481 """
-> 2482 return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:414, in chained_operation.<locals>.wrapper(*args, **kwargs)
411 @wraps(function)
412 def wrapper(*args, **kwargs):
413 try:
--> 414 return function(*args, **kwargs)
415 except DistributedOperationException as e:
416 operation = f"{function.__module__}.{function.__name__}"
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:681, in pad_across_processes(tensor, dim, pad_index, pad_first)
678 new_tensor[indices] = tensor
679 return new_tensor
--> 681 return recursively_apply(
682 _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
683 )
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:107, in recursively_apply(func, data, test_type, error_on_other_type, *args, **kwargs)
85 """
86 Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
87
(...)
104 The same data structure as `data` with `func` applied to every object of type `main_type`.
105 """
106 if isinstance(data, (tuple, list)):
--> 107 return honor_type(
108 data,
109 (
110 recursively_apply(
111 func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
112 )
113 for o in data
114 ),
115 )
116 elif isinstance(data, Mapping):
117 return type(data)(
118 {
119 k: recursively_apply(
(...)
123 }
124 )
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:81, in honor_type(obj, generator)
79 return type(obj)(*list(generator))
80 else:
---> 81 return type(obj)(generator)
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:110, in <genexpr>(.0)
85 """
86 Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
87
(...)
104 The same data structure as `data` with `func` applied to every object of type `main_type`.
105 """
106 if isinstance(data, (tuple, list)):
107 return honor_type(
108 data,
109 (
--> 110 recursively_apply(
111 func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
112 )
113 for o in data
114 ),
115 )
116 elif isinstance(data, Mapping):
117 return type(data)(
118 {
119 k: recursively_apply(
(...)
123 }
124 )
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:128, in recursively_apply(func, data, test_type, error_on_other_type, *args, **kwargs)
126 return func(data, *args, **kwargs)
127 elif error_on_other_type:
--> 128 raise TypeError(
129 f"Unsupported types ({type(data)}) passed to `{func.__name__}`. Only nested list/tuple/dicts of "
130 f"objects that are valid for `{test_type.__name__}` should be passed."
131 )
132 return data
Hey @Jellymoon, the Mamba model works as expected during the training loop, but it fails during the evaluation loop. I found that it is necessary to set use_cache=False when loading the model so that evaluation does not fail.
cc: @Adibvafa
model = MambaForSequenceClassification.from_pretrained(
model_path,
num_labels=len(id2label),
id2label=id2label,
label2id=label2id,
    use_cache=False  # must be passed when evaluating while training Mamba for sequence classification; otherwise the eval loop raises an error
)
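If you need to keep the cache enabled, an alternative might be to tell the Trainer to drop the cache from the outputs it collects during evaluation. A sketch, assuming the output field is named cache_params (an assumption, not verified against this branch):

# Option A: per call, so the MambaCache never reaches accelerate's pad_across_processes
metrics = trainer.evaluate(ignore_keys=["cache_params"])

# Option B: set once on the config, so every evaluate()/predict() call ignores it
model.config.keys_to_ignore_at_inference = ["cache_params"]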
I noticed that the training speed (fine-tuning) is very slow compared to the other HF transformer models. Can something be improved here?
I will take a look. Thank you for bringing this up! @Jellymoon @mohith7548
Do you have mamba-ssm installed? Is it slow in the classification or in Mamba in general?
@Adibvafa, I have mamba-ssm installed. However, I realized that it also needs the causal-conv1d>=1.4.0 package to train faster; otherwise it shows a warning related to conv1d saying it will fall back to the slow/sequential version. Now that I have installed causal-conv1d>=1.4.0, fine-tuning works as expected.
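In case it helps others, a small sketch for checking whether the fast kernels are importable before training; the import paths below are the ones the Mamba modeling code relies on, so treat them as an assumption if your package versions differ.

# Check whether the fused CUDA kernels are available; without both packages the
# modeling code warns and falls back to the much slower sequential implementation.
def fast_mamba_path_available() -> bool:
    try:
        from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # noqa: F401
        from causal_conv1d import causal_conv1d_fn  # noqa: F401
    except ImportError:
        return False
    return True

print("fast Mamba kernels available:", fast_mamba_path_available())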
Amazing! There is currently a bug with the slow training path that either breaks low-precision training or uses a huge amount of memory at once. I suggest opening an issue for the memory surge; I have already opened one for the low-precision training error and am currently working on it.
@Adibvafa, a bug in Mamba or in transformers? Can you elaborate? Please share a link to the issue.
I successfully ran the Mamba model with the new changes you made to the code. Any chance that this will also support the Mamba2 model?
There is currently a bug with the slow training path that either breaks low-precision training or uses a huge amount of memory at once. I suggest opening an issue for the memory surge; I have already opened one for the low-precision training error and am currently working on it.
I think the low-precision bug possibly refers to #32691. The huge amount of memory in the slow path is to be expected though and is one of the reasons why the kernel exists (i.e. to avoid materializing certain tensors etc). Nothing you can really do about this tbh. cc @mohith7548