
Implement MambaForSequenceClassification

Open Adibvafa opened this issue 1 year ago • 24 comments

What does this PR do?

Adds the MambaForSequenceClassification model based on the MambaModel backbone.

We recently published EHRMamba, a state-of-the-art foundation model for Electronic Health Records. This model is built on the same architecture and we will release the trained weights using the MambaForSequenceClassification class. https://vectorinstitute.github.io/EHRMamba

Fixes #30431

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [X] Did you read the contributor guideline, Pull Request section?
  • [X] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case. https://github.com/huggingface/transformers/issues/30431
  • [X] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [X] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

As discussed in https://github.com/huggingface/transformers/issues/30431, @ArthurZucker could you take a look? 😊

Notes

This implementation closely follows GPT2ForSequenceClassification, except that it pools the last hidden states before passing them to the classifier, which improves efficiency.
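To illustrate why the order matters, here is a toy, pure-Python sketch (not the PR's code). Lists stand in for tensors, and it assumes right padding and a bias-free linear head. The helper names are hypothetical. `last_token_index` mirrors the GPT-2 recipe for locating the last non-padding token. Scoring every position and then pooling gives the same logits as pooling first and scoring once, but it performs `seq_len` times more multiplications.

```python
# Toy, pure-Python comparison of two orders of operations for a
# sequence-classification head. Lists stand in for tensors.


def last_token_index(input_ids, pad_token_id):
    """Index of the last non-padding token, assuming right padding.

    Mirrors the GPT-2 recipe: locate the first pad token, step back one, and
    wrap with modulo so an unpadded row maps to its final position.
    """
    first_pad = input_ids.index(pad_token_id) if pad_token_id in input_ids else 0
    return (first_pad - 1) % len(input_ids)


def linear(weight, vector):
    """Bias-free linear layer; `weight` has shape num_labels x hidden_size."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weight]


def score_then_pool(hidden_states, weight, idx):
    """GPT-2 order: classify every position, then keep the logits at `idx`.

    Returns (logits, number of multiplications performed).
    """
    all_logits = [linear(weight, h) for h in hidden_states]
    return all_logits[idx], len(hidden_states) * len(weight) * len(weight[0])


def pool_then_score(hidden_states, weight, idx):
    """Pool first: keep the hidden state at `idx`, then classify it once."""
    return linear(weight, hidden_states[idx]), len(weight) * len(weight[0])


input_ids = [5, 8, 2, 0, 0]  # right-padded, pad_token_id = 0
hidden_states = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [0.0] * 3, [0.0] * 3]
weight = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]  # 2 labels x 3 hidden units

idx = last_token_index(input_ids, pad_token_id=0)
a, cost_a = score_then_pool(hidden_states, weight, idx)
b, cost_b = pool_then_score(hidden_states, weight, idx)
print(idx, a == b, cost_a, cost_b)  # prints: 2 True 30 6
```

With real models the saving grows with sequence length, because only one hidden state per example ever reaches the classifier.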

Adibvafa • May 31 '24

Referring to https://github.com/huggingface/transformers/pull/29552, "there's a test specific to sequence classification that expects all the unfrozen params to be initialized in the range [0.0, 1.0] and the initialized values for the mixer don't satisfy that assertion."

This results in a test failure even though the classifier head is initialized properly.
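A rough check of why the mixer trips that assertion is shown below. It assumes the mixer initializes `A_log` as log(1), ..., log(d_state) for every channel, which is our reading of HF's `MambaMixer`. Under that assumption the mean of `A_log` is already well above 1.0 for the default `state_size` of 16. The function name is hypothetical, and the snippet does not reproduce the test's exact check.

```python
import math


def a_log_mean(d_state: int) -> float:
    """Mean of A_log, assuming every channel holds log(1), ..., log(d_state).

    All rows are identical under that assumption, so the parameter's mean
    equals the mean of log(k) for k = 1..d_state.
    """
    return sum(math.log(k) for k in range(1, d_state + 1)) / d_state


mean = a_log_mean(16)  # default state_size in MambaConfig
print(round(mean, 3), 0.0 <= mean <= 1.0)  # prints: 1.917 False
```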

Adibvafa • May 31 '24

Could you rebase on main and make sure the CIs are green! 🤗

ArthurZucker • Jun 06 '24

Of course! It should be good to merge now. There is a failing test in "MobileViTV2ModelTest" (or similar), which is unrelated to Mamba.

Adibvafa • Jun 07 '24

@ArthurZucker I looked into earlier implementations of decoder models for sequence classification and realized that some of them (e.g. GPT-2) support caching. The use case seems to be classifying at different milestones within a sequence: with a cache, you don't need to restart decoding from the beginning, which makes the process more efficient. It's a rare use case but worth having.
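The milestone idea can be illustrated with a toy recurrence. The sketch below is an illustration only: a fixed scalar update `h_t = DECAY * h_{t-1} + GAIN * x_t` stands in for the recurrent state, and it is not Mamba's selective scan. The `classify` threshold head and all names are hypothetical. Resuming from a cached state gives exactly the same predictions as re-reading the prefix, while processing each token only once.

```python
# Toy stand-in for a recurrent state, NOT Mamba's selective scan: a fixed
# scalar recurrence h_t = DECAY * h_{t-1} + GAIN * x_t.
DECAY, GAIN = 0.9, 0.5


def run(tokens, state=0.0):
    """Advance the recurrence over `tokens`; return (final_state, steps_taken)."""
    for x in tokens:
        state = DECAY * state + GAIN * x
    return state, len(tokens)


def classify(state, threshold=1.0):
    """Hypothetical classifier head: threshold the current state."""
    return int(state > threshold)


sequence = [1, 0, 2, 1, 3, 0, 1, 2]
milestones = [3, 6, 8]  # classify after 3, 6 and 8 tokens

# Without a cache: re-read the whole prefix at every milestone.
no_cache_preds, no_cache_steps = [], 0
for m in milestones:
    state, steps = run(sequence[:m])
    no_cache_steps += steps
    no_cache_preds.append(classify(state))

# With a cache: resume from the stored state and read only the new tokens.
cache_preds, cache_steps, cached_state, prev = [], 0, 0.0, 0
for m in milestones:
    cached_state, steps = run(sequence[prev:m], cached_state)
    cache_steps += steps
    prev = m
    cache_preds.append(classify(cached_state))

print(no_cache_preds == cache_preds, no_cache_steps, cache_steps)  # prints: True 17 8
```

Without a cache, the work grows with the sum of the milestone lengths. With a cache, it grows only with the final length.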

Adibvafa • Jun 11 '24

Okay!

ArthurZucker • Jun 19 '24

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker I merged the latest main from huggingface/transformers into my branch. It should be ready to merge.

Adibvafa • Jun 26 '24

@ArthurZucker I would love a review!

Adibvafa • Jul 08 '24

Hey, any updates on this PR?

mohith7548 • Jul 14 '24

For fine-tuning, the docs show only the PEFT approach. Does this also support full fine-tuning using Trainer?

Thanks!

mohith7548 • Jul 14 '24

Yes, it should support it.

Adibvafa • Jul 15 '24

This should be merged soon after @ArthurZucker's final review.

Adibvafa • Jul 15 '24

@amyeroberts, do you have any updates from @ArthurZucker? I haven't heard from him in about a month and just wanted to check that everything is alright.

Adibvafa • Jul 17 '24

Hey, sorry, this slipped through my notifications.

ArthurZucker • Jul 23 '24

@ArthurZucker Makes sense, I changed the classification head to a linear layer.

Adibvafa • Jul 27 '24

@ArthurZucker Pending review!

Adibvafa • Aug 02 '24

@ArthurZucker Now that https://github.com/huggingface/transformers/pull/32080 is merged, can we do a final review of this one too? I would also like to add Mamba2ForSequenceClassification to this PR so that both Mamba models have classification capabilities. Then I would be able to release the EHRMamba model on HuggingFace.

Adibvafa • Aug 09 '24

Hey, any update on this?

mohith7548 • Aug 14 '24

@Adibvafa, have you tried running this? Whenever the model reaches an evaluation step, I get the error below. I ran the Hugging Face sequence classification tutorial (link to tutorial) with three changes: I ran it on a GPU, replaced "distilbert/distilbert-base-uncased" with "state-spaces/mamba-130m-hf", and replaced my local modeling_mamba.py with yours so that the model loads.

ERROR:

Traceback (most recent call last):
  File ".../tutorial_script.py", line 71, in <module>
  File ".../lib/python3.11/site-packages/transformers/trainer.py", line 3754, in predict
    output = eval_loop(
             ^^^^^^^^^^
  File ".../lib/python3.11/site-packages/transformers/trainer.py", line 3887, in evaluation_loop
    logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/accelerator.py", line 2508, in pad_across_processes
    return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 411, in wrapper
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 678, in pad_across_processes
    return recursively_apply(
           ^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 107, in recursively_apply
    return honor_type(
           ^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 81, in honor_type
    return type(obj)(generator)
           ^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 110, in <genexpr>
    recursively_apply(
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 128, in recursively_apply
    raise TypeError(
TypeError: Unsupported types (<class 'transformers.cache_utils.MambaCache'>) passed to `_pad_across_processes`. 
Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.

Jellymoon • Aug 15 '24

I installed your fork of the transformers repo with %pip install -qq git+https://github.com/Adibvafa/MambaForSequenceClassification.git, and I'm trying to fine-tune on a text-classification dataset. I get an error similar to @Jellymoon's. I'm trying to understand your code. Can you help me figure out what the issue might be, @Adibvafa?

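import mlflow
from transformers import DataCollatorWithPadding, EarlyStoppingCallback, TrainingArguments

mlflow.set_experiment(MLFLOW_EXPERIMENT)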
MODEL_OUTPUT_DIR = f"{TRAINER_DIR}/{DATASET_BASE_NAME}/{ts}"

training_args = TrainingArguments(
    output_dir=MODEL_OUTPUT_DIR,
    overwrite_output_dir=True,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=WEIGHT_DECAY,
    label_smoothing_factor=LABEL_SMOOTHING,  # without label smoothing the model becomes overconfident; keep a non-zero value.
    # bf16=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    group_by_length=True, # smart batching. useful for dynamic padding.
    load_best_model_at_end=True,
    disable_tqdm=False,
    save_total_limit=3,
    metric_for_best_model='f1',  # must be one of the metrics returned by `compute_metrics`
    greater_is_better=True,
    # auto_find_batch_size=True, # will only reduce, might not be optimal.
)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"].select(range(50)),
    eval_dataset=tokenized_data["valid"],
    tokenizer=tokenizer,
    data_collator=data_collator, # aids in dynamic padding.
    compute_metrics=compute_metrics,
)

# trainer.add_callback(EvalTrainDataCallback(trainer)) # had to remove because it was messing with early stopping.
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)) 

Std err/output:

[2024-08-25 06:37:20,144] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.3
 [WARNING]  using untested triton version (2.3.1), only 1.0.0 is known to be compatible
[rank0]:[W reducer.cpp:1389] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())

TypeError: Unsupported types (<class 'transformers.cache_utils.MambaCache'>) passed to `_pad_across_processes`. Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.
File <command-3628464069596680>, line 10
      7 mlflow.set_tags(tags)
      9 # Start Training
---> 10 trainer.train()
     12 # Save trained model to local dir
     13 PIPELINE_DIR = f'{MODEL_OUTPUT_DIR}/pipeline'
File /databricks/python/lib/python3.11/site-packages/mlflow/utils/autologging_utils/safety.py:456, in safe_patch.<locals>.safe_patch_function(*args, **kwargs)
    441 if (
    442     active_session_failed
    443     or autologging_is_disabled(autologging_integration)
   (...)
    450     # warning behavior during original function execution, since autologging is being
    451     # skipped
    452     with set_non_mlflow_warnings_behavior_for_current_thread(
    453         disable_warnings=False,
    454         reroute_warnings=False,
    455     ):
--> 456         return original(*args, **kwargs)
    458 # Whether or not the original / underlying function has been called during the
    459 # execution of patched code
    460 original_has_been_called = False
File /databricks/python_shell/dbruntime/huggingface_patches/transformers.py:54, in _create_patch_function.<locals>.patched_fit_function(self, *args, **kwargs)
     52 call_succeeded = False
     53 try:
---> 54     model = original_method(self, *args, **kwargs)
     55     call_succeeded = True
     56     return model
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:1954, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1952         hf_hub_utils.enable_progress_bars()
   1953 else:
-> 1954     return inner_training_loop(
   1955         args=args,
   1956         resume_from_checkpoint=resume_from_checkpoint,
   1957         trial=trial,
   1958         ignore_keys_for_eval=ignore_keys_for_eval,
   1959     )
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:2392, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2389     self.control.should_training_stop = True
   2391 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 2392 self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2394 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   2395     if is_torch_xla_available():
   2396         # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:2820, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2818 metrics = None
   2819 if self.control.should_evaluate:
-> 2820     metrics = self._evaluate(trial, ignore_keys_for_eval)
   2822 if self.control.should_save:
   2823     self._save_checkpoint(model, trial, metrics=metrics)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:2777, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2776 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2777     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2778     self._report_to_hp_search(trial, self.state.global_step, metrics)
   2780     # Run delayed LR scheduler now that metrics are populated
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:3701, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3698 start_time = time.time()
   3700 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3701 output = eval_loop(
   3702     eval_dataloader,
   3703     description="Evaluation",
   3704     # No point gathering the predictions if there are no metrics, otherwise we defer to
   3705     # self.args.prediction_loss_only
   3706     prediction_loss_only=True if self.compute_metrics is None else None,
   3707     ignore_keys=ignore_keys,
   3708     metric_key_prefix=metric_key_prefix,
   3709 )
   3711 total_batch_size = self.args.eval_batch_size * self.args.world_size
   3712 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-266acfb6-094d-4adf-baef-dac3d0ca7200/lib/python3.11/site-packages/transformers/trainer.py:3912, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3910     labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
   3911 if logits is not None:
-> 3912     logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
   3913     if self.preprocess_logits_for_metrics is not None:
   3914         logits = self.preprocess_logits_for_metrics(logits, labels)
File /databricks/python/lib/python3.11/site-packages/accelerate/accelerator.py:2482, in Accelerator.pad_across_processes(self, tensor, dim, pad_index, pad_first)
   2449 def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
   2450     """
   2451     Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
   2452     they can safely be gathered.
   (...)
   2480     ```
   2481     """
-> 2482     return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:414, in chained_operation.<locals>.wrapper(*args, **kwargs)
    411 @wraps(function)
    412 def wrapper(*args, **kwargs):
    413     try:
--> 414         return function(*args, **kwargs)
    415     except DistributedOperationException as e:
    416         operation = f"{function.__module__}.{function.__name__}"
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:681, in pad_across_processes(tensor, dim, pad_index, pad_first)
    678     new_tensor[indices] = tensor
    679     return new_tensor
--> 681 return recursively_apply(
    682     _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
    683 )
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:107, in recursively_apply(func, data, test_type, error_on_other_type, *args, **kwargs)
     85 """
     86 Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
     87 
   (...)
    104     The same data structure as `data` with `func` applied to every object of type `main_type`.
    105 """
    106 if isinstance(data, (tuple, list)):
--> 107     return honor_type(
    108         data,
    109         (
    110             recursively_apply(
    111                 func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
    112             )
    113             for o in data
    114         ),
    115     )
    116 elif isinstance(data, Mapping):
    117     return type(data)(
    118         {
    119             k: recursively_apply(
   (...)
    123         }
    124     )
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:81, in honor_type(obj, generator)
     79     return type(obj)(*list(generator))
     80 else:
---> 81     return type(obj)(generator)
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:110, in <genexpr>(.0)
     85 """
     86 Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
     87 
   (...)
    104     The same data structure as `data` with `func` applied to every object of type `main_type`.
    105 """
    106 if isinstance(data, (tuple, list)):
    107     return honor_type(
    108         data,
    109         (
--> 110             recursively_apply(
    111                 func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
    112             )
    113             for o in data
    114         ),
    115     )
    116 elif isinstance(data, Mapping):
    117     return type(data)(
    118         {
    119             k: recursively_apply(
   (...)
    123         }
    124     )
File /databricks/python/lib/python3.11/site-packages/accelerate/utils/operations.py:128, in recursively_apply(func, data, test_type, error_on_other_type, *args, **kwargs)
    126     return func(data, *args, **kwargs)
    127 elif error_on_other_type:
--> 128     raise TypeError(
    129         f"Unsupported types ({type(data)}) passed to `{func.__name__}`. Only nested list/tuple/dicts of "
    130         f"objects that are valid for `{test_type.__name__}` should be passed."
    131     )
    132 return data

mohith7548 • Aug 25 '24

Hey @Jellymoon, the Mamba model works as expected during the training loop but fails during the evaluation loop. I found that it is necessary to set use_cache=False when loading the model so that evaluation does not fail.

cc: @Adibvafa

from transformers import MambaForSequenceClassification  # provided by this PR's branch

model = MambaForSequenceClassification.from_pretrained(
    model_path, 
    num_labels=len(id2label), 
    id2label=id2label, 
    label2id=label2id,
    use_cache=False,  # required so that evaluation during training does not raise a TypeError
)
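Why does use_cache=False help? Judging from the tracebacks, when caching is enabled the classification output also carries a MambaCache object. Trainer's evaluation loop then passes every non-loss output to `pad_across_processes`, which accepts only tensors. The sketch below is a simplified, hypothetical mimic of that filtering step, not Trainer's actual code. Lists stand in for tensors and `FakeMambaCache` is a placeholder. The output key name `cache_params` is an assumption borrowed from `MambaOutput` and should be checked against the PR's output class.

```python
# Simplified, hypothetical mimic of how an evaluation loop turns a model's
# output dict into the "logits" it gathers: every entry except the loss and
# any ignored keys is kept. Lists stand in for tensors.


class FakeMambaCache:
    """Placeholder for transformers.cache_utils.MambaCache (not a tensor)."""


def collect_logits(outputs, ignore_keys):
    return tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])


def is_tensor_like(obj):
    # Stand-in for the is_torch_tensor check that accelerate applies.
    return isinstance(obj, list)


def gatherable(logits):
    return all(is_tensor_like(v) for v in logits)


with_cache = {"loss": [0.7], "logits": [[0.1, 0.9]], "cache_params": FakeMambaCache()}
without_cache = {"loss": [0.7], "logits": [[0.1, 0.9]]}  # assumed effect of use_cache=False

print(gatherable(collect_logits(with_cache, [])))                # False: the TypeError case
print(gatherable(collect_logits(without_cache, [])))             # True
print(gatherable(collect_logits(with_cache, ["cache_params"])))  # True: key ignored
```

If caching should stay enabled, Trainer also accepts an ignore list for this purpose, e.g. `trainer.train(ignore_keys_for_eval=["cache_params"])` or `trainer.evaluate(ignore_keys=["cache_params"])`. A model config can also declare `keys_to_ignore_at_inference`; GPT-2's config lists `past_key_values` there, which plausibly explains why GPT2ForSequenceClassification does not hit this error. The key name remains an assumption.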

mohith7548 • Aug 25 '24

I noticed that the training speed (fine-tuning) is very slow compared to the other HF transformer models. Can something be improved here?

mohith7548 • Aug 25 '24

Regarding the evaluation error and the use_cache=False workaround:

I will take a look. Thank you for bringing this up! @Jellymoon @mohith7548

Adibvafa • Aug 27 '24

Do you have mamba-ssm installed? Is it slow in the classification or in Mamba in general?

Adibvafa • Aug 27 '24

@Adibvafa, I have mamba-ssm installed. However, I realized that it also needs the causal-conv1d>=1.4.0 package to train faster. Otherwise, it shows a conv1d-related warning saying that it will use the slow, sequential version. Now that I have installed causal-conv1d>=1.4.0, fine-tuning works as expected.
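A quick pre-training check for the fast-path packages can be written with the standard library alone. The sketch below is hypothetical: the function names are illustrative, and it uses the `>=1.4.0` threshold mentioned above. It only checks that `mamba_ssm` and `causal_conv1d` are installed, not that their CUDA kernels actually load. transformers also has `is_mamba_ssm_available` and `is_causal_conv1d_available` in `transformers.utils.import_utils`, which may be preferable in a project that already depends on it.

```python
import importlib.util
import re
from importlib.metadata import PackageNotFoundError, version


def version_tuple(text, width=3):
    """Leading numeric release parts, zero-padded: '1.4.0.post2' -> (1, 4, 0)."""
    parts = []
    for piece in text.split(".")[:width]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts + [0] * (width - len(parts)))


def installed_version(*dist_names):
    """Return the first installed version among `dist_names`, or None."""
    for name in dist_names:
        try:
            return version(name)
        except PackageNotFoundError:
            continue
    return None


def kernel_status(min_conv1d=(1, 4, 0)):
    """Report whether the optional Mamba kernel packages appear to be installed."""
    conv1d_version = installed_version("causal-conv1d", "causal_conv1d")
    return {
        "mamba_ssm": importlib.util.find_spec("mamba_ssm") is not None,
        "causal_conv1d": importlib.util.find_spec("causal_conv1d") is not None,
        "causal_conv1d_recent": conv1d_version is not None
        and version_tuple(conv1d_version) >= min_conv1d,
    }


print(kernel_status())
```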

mohith7548 • Aug 28 '24

Amazing! There is currently a bug in the slow training path: it either breaks in low-precision training or uses a huge amount of memory at once. I suggest opening an issue for the memory surge. I have opened an issue and am currently working on the low-precision training error.

Adibvafa • Aug 28 '24

@Adibvafa, is this a bug in Mamba or in transformers? Can you elaborate? Please share the link to the issue.

mohith7548 • Aug 28 '24

I successfully ran the Mamba model with the new changes you made to the code. Any chance that this will also support the Mamba2 model?

mohith7548 • Aug 28 '24

I think the low-precision bug possibly refers to #32691. The huge amount of memory in the slow path is to be expected though and is one of the reasons why the kernel exists (i.e. to avoid materializing certain tensors etc). Nothing you can really do about this tbh. cc @mohith7548
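Here is a rough number for the slow path's memory use, under stated assumptions. We assume the slow path materializes intermediates of shape (batch, intermediate_size, seq_len, state_size) in float32; this is our reading of the shape comments in HF's `slow_forward`. We use mamba-130m's dimensions, the per-device batch size of 8 from the training arguments earlier in this thread, and a hypothetical sequence length of 1024. Even one such tensor per layer is large. During training, autograd generally has to keep intermediates like this for the backward pass.

```python
# Back-of-the-envelope size of one intermediate the slow (pure PyTorch) path
# materializes per layer, assuming shape (batch, intermediate_size, seq_len,
# state_size) and 4-byte float32 elements.


def slow_path_tensor_bytes(batch, intermediate_size, seq_len, state_size, bytes_per_elem=4):
    return batch * intermediate_size * seq_len * state_size * bytes_per_elem


# mamba-130m: hidden_size 768, expand 2 -> intermediate_size 1536, state_size 16, 24 layers.
per_tensor = slow_path_tensor_bytes(batch=8, intermediate_size=1536, seq_len=1024, state_size=16)
print(per_tensor / 2**20, "MiB per tensor per layer")  # prints: 768.0 MiB per tensor per layer
print(per_tensor * 24 / 2**30, "GiB if one is kept for each of 24 layers")  # prints: 18.0 GiB ...
```

This scaling with batch × seq_len × state_size is what the fused kernels are designed to avoid, in line with the point about not materializing certain tensors.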

vasqu • Sep 05 '24