
Can `MambaForCausalLM` be used directly for training instead of `AutoModelForCausalLM`?

Open · LumenScopeAI opened this issue 10 months ago · 2 comments

Hello,

I'm currently using the transformers library to train a causal language model with the MambaForCausalLM class. However, I've noticed that the library's typical training examples load the model with AutoModelForCausalLM, and I'm wondering whether it is possible, and recommended, to use MambaForCausalLM directly for training instead.

Here is the code snippet I'm referring to:

```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

# Load the tokenizer and model from the Hugging Face-format Mamba checkpoint
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
dataset = load_dataset("Abirate/english_quotes", split="train")

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir="./logs",
    logging_steps=10,
    learning_rate=2e-3,
)

# LoRA adapters on the Mamba projection layers and the token embeddings
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)

# Supervised fine-tuning on the "quote" text column of the dataset
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
```

For inference, I use MambaForCausalLM successfully as follows:

```python
from transformers import MambaForCausalLM, AutoTokenizer

# Load the same checkpoint, this time through the concrete Mamba class
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

# Generate up to 10 new tokens and decode them back to text
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```

Could you clarify whether using MambaForCausalLM for training is supported, and whether any additional configuration is required?

Thank you for your assistance.

LumenScopeAI, Apr 16 '24 08:04