
Training a text classifier on Mamba

Open mdabedr opened this issue 1 year ago • 7 comments

Hello. Can you please share some insights on how one can train a text classifier with Mamba?

mdabedr avatar Feb 05 '24 18:02 mdabedr

The same way you would do it with a Transformer. Two differences come to mind:

  • If you don't need causality, you'd want to use a bidirectional version of Mamba. We have a preprint coming out soon, but it should be reasonable to just create two copies of the inner SSM within Mamba's block, run them in both directions, and sum the outputs (see the sketch after this list).
  • Transformers sometimes use a classification token. You could try this with Mamba too, or just try pooling all the output features before the classification head.
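
A minimal sketch of the bidirectional idea, assuming the standard Mamba block from mamba_ssm (the class name and wiring here are illustrative, not from the preprint):

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class BiMambaBlock(nn.Module):
    """Illustrative bidirectional block: two independent Mamba copies, one run
    over the sequence as given and one over the reversed sequence; the two
    outputs are summed."""
    def __init__(self, d_model, **mamba_kwargs):
        super().__init__()
        self.fwd = Mamba(d_model=d_model, **mamba_kwargs)
        self.bwd = Mamba(d_model=d_model, **mamba_kwargs)

    def forward(self, x):  # x: (batch, seqlen, d_model)
        out_fwd = self.fwd(x)
        out_bwd = self.bwd(torch.flip(x, dims=[1]))
        return out_fwd + torch.flip(out_bwd, dims=[1])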

albertfgu avatar Feb 05 '24 18:02 albertfgu

Thank you for your reply. Is the MambaLMHeadModel defined in mixer_seq_simple.py useful in this particular application?

mdabedr avatar Feb 05 '24 18:02 mdabedr

You'd probably want to write a MambaClassifierHeadModel that has a similar structure: a Mamba model backbone with a classifier head.
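
A rough, simplified sketch of that structure (illustrative names; plain residual blocks and mean pooling rather than the actual MixerModel backbone):

import torch.nn as nn
from mamba_ssm import Mamba

class MambaClassifierHeadModel(nn.Module):
    """Sketch: token embedding -> stack of Mamba blocks -> mean pooling -> classifier head."""
    def __init__(self, vocab_size, d_model, n_layer, num_labels, **mamba_kwargs):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList(
            [Mamba(d_model=d_model, **mamba_kwargs) for _ in range(n_layer)]
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_labels)

    def forward(self, input_ids):      # (batch, seqlen) token ids
        x = self.embedding(input_ids)  # (batch, seqlen, d_model)
        for layer in self.layers:
            x = x + layer(x)           # simple residual around each Mamba block
        x = self.norm(x)
        pooled = x.mean(dim=1)         # pool output features over the sequence
        return self.head(pooled)       # (batch, num_labels)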

tridao avatar Feb 05 '24 18:02 tridao

Trying to do the same task here, but an error occurs:

~/anaconda3/envs/myenv/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py:136, in Mamba.forward(self, hidden_states, inference_params)
    132         return out
    134 # We do matmul and transpose BLH -> HBL at the same time
    135 xz = rearrange(
--> 136     self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
    137     "d (b l) -> b d l",
    138     l=seqlen,
    139 )
    140 if self.in_proj.bias is not None:
    141     xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1 and 12x1920)

My input is (128, 15, 12), and the Mamba model is configured as

d_model=12,  # Model dimension d_model
d_state=16,   # SSM state expansion factor
d_conv=4,     # Local convolution width
expand=2      # Block expansion factor

Any insights on this? Thanks.

yudizhangzyd avatar Feb 08 '24 18:02 yudizhangzyd

The inputs to the Mamba block probably aren't formatted correctly. Check the example in the README and double check all your shapes.
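
For reference, the README usage expects inputs of shape (batch, seqlen, d_model), with d_model equal to the last dimension of the input. A quick shape check along those lines, adapted to the shapes above, would be:

import torch
from mamba_ssm import Mamba

batch, length, dim = 128, 15, 12
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    d_model=dim,  # must match the last dimension of x
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape  # (128, 15, 12)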

albertfgu avatar Feb 08 '24 18:02 albertfgu

That's weird, my input data runs through LSTMs/Transformers successfully.

yudizhangzyd avatar Feb 08 '24 21:02 yudizhangzyd

One way is to use a pretrained model from Hugging Face together with the transformers library:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

num_labels = 2  # the number of labels

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m")

# Replace the language-modeling head with a classification head
model.lm_head = torch.nn.Linear(model.config.d_model, num_labels)

From here, you fine-tune the resulting model on your classification task. This approach was used with great success on Kaggle, where you can find more details: https://www.kaggle.com/competitions/llm-detect-ai-generated-text/discussion/470093
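
As a hedged sketch of what a single fine-tuning step might look like with this setup (assuming the model's output object exposes a .logits tensor, which after the head swap has shape (batch, seqlen, num_labels); the label here is a dummy for illustration):

device = "cuda"  # Mamba's fused kernels expect a CUDA device
model = model.to(device)

inputs = tokenizer("example text to classify", return_tensors="pt").to(device)
out = model(inputs.input_ids)      # output with .logits of shape (batch, seqlen, num_labels)
logits = out.logits[:, -1, :]      # use the last position as the sequence-level prediction
labels = torch.tensor([1], device=device)

loss = torch.nn.functional.cross_entropy(logits, labels)
loss.backward()  # from here, plug into your usual optimizer/training loop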

maksymdolgikh avatar Feb 15 '24 15:02 maksymdolgikh