Can `MambaForCausalLM` be used directly for training instead of `AutoModelForCausalLM`?
Hello,
I'm currently working with the transformers library to train a model on causal language modeling tasks using the MambaForCausalLM class. However, I've noticed that the typical training setup in the library loads the model with AutoModelForCausalLM, and I'm wondering whether it is possible and recommended to use MambaForCausalLM directly for training instead.
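My understanding (which I would like confirmed) is that AutoModelForCausalLM only dispatches on the checkpoint's config, so for this checkpoint it should already resolve to a MambaForCausalLM instance. A quick check along these lines is what I have in mind; the comments show my expected output, not something I have verified:

from transformers import AutoModelForCausalLM, MambaForCausalLM

model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")

# If the dispatch works the way I think it does, both of these should hold:
print(type(model).__name__)                 # expected: "MambaForCausalLM"
print(isinstance(model, MambaForCausalLM))  # expected: True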
Here is the code snippet I'm referring to:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
dataset = load_dataset("Abirate/english_quotes", split="train")

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3,
)

# LoRA adapters on the Mamba projection layers and the embeddings
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)

trainer.train()
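Concretely, the only change I am considering is the line that loads the model. Assuming the two classes are interchangeable here, everything else above (the TrainingArguments, LoraConfig, and SFTTrainer setup) would stay exactly the same:

from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")

# Proposed swap (untested on my side): load the concrete class directly
# instead of going through AutoModelForCausalLM.
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")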
For inference, I already use MambaForCausalLM successfully, as follows:
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
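Before wiring the direct class into the trainer, the sanity check I plan to run is a plain forward pass with labels, to confirm that MambaForCausalLM returns a loss that can be backpropagated. This is a sketch of my intent, not verified output:

from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")

# For causal language modeling the labels are typically the input ids themselves;
# as far as I understand, the model shifts them internally and returns a loss.
batch = tokenizer("Hey how are you doing?", return_tensors="pt")
outputs = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
print(outputs.loss)
outputs.loss.backward()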
Could you clarify whether using MambaForCausalLM for training is supported, and whether any additional configuration is required for this?

Thank you for your assistance.