Improve Mamba Speed
Hello AnFreTh,
Thank you for your work on this project. I am currently using Mambular to process tabular data, but I am experiencing very slow training speeds. On average, each epoch is taking around 80 minutes to complete.
Here are the details of my setup:
- Batch size: 256
- GPU: NVIDIA 4090
- Data: 8139 samples, each with 425 features
- Model settings: default parameters
For comparison, when I use ResNet or FT-Transformer as the tabular encoder with the same setup, the training speed is approximately 25 seconds per epoch, which is significantly faster. Is it expected that Mambular would be much slower than ResNet or FT-Transformer? Or could this be an issue with my configuration or code?
I would appreciate any insight you could provide. Is there any known issue, or something I can adjust in my configuration to improve the speed?
Please let me know if you need additional information to help diagnose the problem.
Thank you for your time and assistance!
It is expected that Mambular is slower than e.g. FT-Transformer, especially for datasets with a lot of features, since training time increases linearly with sequence length (number of features). However, in our experiments the slowdown was only by a factor of about 2.5-3, while Mambular was more memory efficient than FT-Transformer.
Could you provide a minimal code example with simulated data where you experience similar training times? Then we can verify.
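In the meantime, a rough way to see this scaling is a sketch like the one below: it builds the Mamba block from the fields of DefaultMambularConfig and times a forward/backward pass for a few sequence lengths (a CUDA device is assumed, and the exact numbers will of course depend on your hardware):

import time
import torch
import mambular.arch_utils.normalization_layers as norm_layers
from mambular.arch_utils.mamba_arch import Mamba
from mambular.configs.mambular_config import DefaultMambularConfig

cfg = DefaultMambularConfig()
# build the Mamba block with its default configuration
block = Mamba(
    d_model=cfg.d_model,
    n_layers=cfg.n_layers,
    expand_factor=cfg.expand_factor,
    bias=cfg.bias,
    d_conv=cfg.d_conv,
    conv_bias=cfg.conv_bias,
    dropout=cfg.dropout,
    dt_rank=cfg.dt_rank,
    d_state=cfg.d_state,
    dt_scale=cfg.dt_scale,
    dt_init=cfg.dt_init,
    dt_max=cfg.dt_max,
    dt_min=cfg.dt_min,
    dt_init_floor=cfg.dt_init_floor,
    norm=getattr(norm_layers, cfg.norm),
    activation=cfg.activation,
    bidirectional=cfg.bidirectional,
    use_learnable_interaction=cfg.use_learnable_interaction,
    AD_weight_decay=cfg.AD_weight_decay,
    BC_layer_norm=cfg.BC_layer_norm,
    layer_norm_eps=cfg.layer_norm_eps,
).cuda()

# time one forward + backward pass per sequence length (batch size 256)
for seq_len in (32, 128, 512):
    x = torch.randn(256, seq_len, cfg.d_model, device="cuda")
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(5):
        block(x).sum().backward()
    torch.cuda.synchronize()
    print(f"seq_len={seq_len}: {(time.time() - start) / 5:.3f} s per step")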
Hello AnFreTh,
Thank you for your reply. Based on your suggestion, I have prepared a minimal code example for you to review.
In my current framework, I am using Mambular as the tabular encoder within a table-image contrastive learning setting. I defined a CustomMambularEncoder class, making the following modifications:
- I used an embedding method consistent with FT-Transformer due to differences in how the data is read.
- I removed the classification head, so the encoder outputs only the feature vectors from Mambular.
For simplicity, the provided code example uses only simulated tabular data. My full dataset has 8139 samples in total, of which 6530 are split between the training and validation sets used here. Each sample consists of 423 numerical features only, with no categorical features.
When running this simplified code (with a batch size of 16), training the Mambular encoder takes approximately 2.5 hours per epoch, while using the FT-Transformer encoder takes around 15 seconds per epoch, and using ResNet as the encoder takes about 7 seconds per epoch.
I have attached the code example for your review. Please let me know if anything else is needed to further investigate the issue.
Thank you again for your help!
import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from mambular.arch_utils.mamba_arch import Mamba
from mambular.arch_utils.normalization_layers import (
RMSNorm,
LayerNorm,
LearnableLayerScaling,
BatchNorm,
InstanceNorm,
GroupNorm,
)
from mambular.configs.mambular_config import DefaultMambularConfig
from mambular.base_models.basemodel import BaseModel
from typing import List
from torch import Tensor
# embedding methods from FT-Transformer
from models.rtdl_revisiting_models import LinearEmbeddings, _CLSEmbedding
class CustomMambularEncoder(BaseModel):
"""
Modified encoder based on Mambular:
- the embedding layer is modified to be consistent with FT-Transformer.
- the tabular head is removed.
"""
def __init__(
self,
n_cont_features: int,
cat_cardinalities: List[int],
n_categories: List[int],
config: DefaultMambularConfig = DefaultMambularConfig(),
**kwargs,
):
super().__init__(**kwargs)
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
self.lr = self.hparams.get("lr", config.lr)
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
self.shuffle_embeddings = self.hparams.get(
"shuffle_embeddings", config.shuffle_embeddings
)
self.mamba = Mamba(
d_model=self.hparams.get("d_model", config.d_model),
n_layers=self.hparams.get("n_layers", config.n_layers),
expand_factor=self.hparams.get("expand_factor", config.expand_factor),
bias=self.hparams.get("bias", config.bias),
d_conv=self.hparams.get("d_conv", config.d_conv),
conv_bias=self.hparams.get("conv_bias", config.conv_bias),
dropout=self.hparams.get("dropout", config.dropout),
dt_rank=self.hparams.get("dt_rank", config.dt_rank),
d_state=self.hparams.get("d_state", config.d_state),
dt_scale=self.hparams.get("dt_scale", config.dt_scale),
dt_init=self.hparams.get("dt_init", config.dt_init),
dt_max=self.hparams.get("dt_max", config.dt_max),
dt_min=self.hparams.get("dt_min", config.dt_min),
dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor),
norm=globals()[self.hparams.get("norm", config.norm)],
activation=self.hparams.get("activation", config.activation),
            bidirectional=self.hparams.get("bidirectional", config.bidirectional),
            use_learnable_interaction=self.hparams.get(
                "use_learnable_interaction", config.use_learnable_interaction
            ),
            AD_weight_decay=self.hparams.get("AD_weight_decay", config.AD_weight_decay),
            BC_layer_norm=self.hparams.get("BC_layer_norm", config.BC_layer_norm),
layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
)
norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
self.norm_f = RMSNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "LayerNorm":
self.norm_f = LayerNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "BatchNorm":
self.norm_f = BatchNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "InstanceNorm":
self.norm_f = InstanceNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "GroupNorm":
self.norm_f = GroupNorm(
1,
self.hparams.get("d_model", config.d_model),
eps=config.layer_norm_eps,
)
elif norm_layer == "LearnableLayerScaling":
self.norm_f = LearnableLayerScaling(
self.hparams.get("d_model", config.d_model)
)
else:
raise ValueError(f"Unsupported normalization layer: {norm_layer}")
# >>> Feature & cls embeddings in FT-Transformer.
self.cont_embeddings = (
            LinearEmbeddings(n_cont_features + len(cat_cardinalities), config.d_model) if n_cont_features > 0 else None
)
self.cls_embedding = _CLSEmbedding(config.d_model)
# <<<
if self.pooling_method == "cls":
self.use_cls = True
else:
self.use_cls = self.hparams.get("use_cls", config.use_cls)
        if self.shuffle_embeddings:
            # sequence length = CLS token + one token per feature
            self.perm = torch.randperm(n_cont_features + len(cat_cardinalities) + 1)
def forward(self, x):
# cls embedding
x_embeddings: List[Tensor] = []
if self.cls_embedding is not None:
x_embeddings.append(self.cls_embedding(x.shape[:-1]))
# feature embedding, only numerical features in this case
x_embeddings.append(self.cont_embeddings(x))
x = torch.cat(x_embeddings, dim=1)
if self.shuffle_embeddings:
x = x[:, self.perm, :]
x = self.mamba(x)
if self.pooling_method == "avg":
x = torch.mean(x, dim=1)
elif self.pooling_method == "max":
x, _ = torch.max(x, dim=1)
elif self.pooling_method == "sum":
x = torch.sum(x, dim=1)
elif self.pooling_method == "cls_token":
x = x[:, -1]
elif self.pooling_method == "last":
x = x[:, -1]
else:
raise ValueError(f"Invalid pooling method: {self.pooling_method}")
x = self.norm_f(x)
return x
class MinimalContrastiveMambularModel(pl.LightningModule):
"""
Contrastive model for tabular data.
"""
def __init__(self, feature_dim=128):
super().__init__()
self.mambular_encoder = CustomMambularEncoder(
n_cont_features=423,
cat_cardinalities=[],
n_categories=[]
)
self.projection_head = torch.nn.Sequential(
torch.nn.Linear(64, feature_dim),
torch.nn.ReLU(),
torch.nn.Linear(feature_dim, feature_dim)
)
def forward(self, x):
encoded = self.mambular_encoder(x)
projected = self.projection_head(encoded)
return F.normalize(projected, dim=1)
def training_step(self, batch, batch_idx):
x1, x2 = batch
z1 = self.forward(x1)
z2 = self.forward(x2)
loss = self.contrastive_loss(z1, z2)
self.log('train_loss', loss)
return loss
def contrastive_loss(self, z1, z2, temperature=0.5):
# NT-Xent Loss
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
batch_size = z1.shape[0]
similarity_matrix = torch.matmul(z1, z2.T) / temperature
labels = torch.arange(batch_size, device=z1.device)
loss = F.cross_entropy(similarity_matrix, labels)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-4)
# Simulated tabular data
# 6530 samples in the training & validation set
# 423 numerical features, 0 categorical features
simulated_data_1 = torch.rand(6530, 423)
simulated_data_2 = torch.rand(6530, 423)
# DataLoader
train_dataset = torch.utils.data.TensorDataset(simulated_data_1, simulated_data_2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16)
# training
model = MinimalContrastiveMambularModel()
trainer = pl.Trainer(max_epochs=5, accelerator="gpu", devices=1, limit_train_batches=1.0)
torch.cuda.empty_cache()
trainer.fit(model, train_loader)
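In case it is useful for profiling, a single forward/backward pass of the model above can also be timed in isolation with a snippet like the following (appended to the script above; numbers will of course depend on the hardware):

import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timing_model = MinimalContrastiveMambularModel().to(device)
xb = torch.rand(16, 423, device=device)

# warm-up
for _ in range(3):
    timing_model(xb).sum().backward()
    timing_model.zero_grad(set_to_none=True)

if device.type == "cuda":
    torch.cuda.synchronize()
start = time.time()
n_steps = 10
for _ in range(n_steps):
    timing_model(xb).sum().backward()
    timing_model.zero_grad(set_to_none=True)
if device.type == "cuda":
    torch.cuda.synchronize()

per_batch = (time.time() - start) / n_steps
print(f"~{per_batch:.3f} s per batch, ~{per_batch * (6530 / 16) / 60:.1f} min per epoch")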
I could not recreate the extreme differences you reported, but default Mambular was still roughly 10x slower than FT-Transformer for this specific setup. We will update the current Mamba block implementation to increase speed.
- update mamba-arch with a parallel scan for a speed improvement.
- include an LSTM option for larger datasets, see the LSTM/GRU identity issue #134
Thank you for taking the time to investigate the issue. I will try these and look forward to your updates. Thanks again for your help and support!
If you experiment further, you could also try the original Mamba implementation instead of the Python implementation from Mambular: https://pypi.org/project/mamba-ssm/ If you do so, please let us know whether it improves speed :).
Since a proper fix in the package might take some time, there are two workarounds you could try in the meantime:
First, try TabulaRNN(model_type="LSTM", d_conv=16), but from the develop branch. Install the package via:
pip install git+https://github.com/basf/mamba-tabular.git@develop
This should be more memory efficient and similar in speed to the FT-Transformer. I would advise increasing the kernel size of the convolution, given your large number of features.
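Roughly, that could look something like the sketch below (the exact import path may differ on the develop branch, so treat it as an assumption and check the source):

# after: pip install git+https://github.com/basf/mamba-tabular.git@develop
from mambular.base_models import TabulaRNN  # import path assumed, verify on the develop branch

# LSTM backbone with a larger convolution kernel, as suggested above
encoder = TabulaRNN(model_type="LSTM", d_conv=16)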
Second, depending on your resources, you could try to leverage the original Mamba implementation. This can be tricky, since not all systems/GPUs are supported.
pip install mamba-ssm
Then simply import Mamba and swap it in for the PyTorch version from Mambular.
from mamba_ssm import Mamba  # you could also try Mamba2

# in your class, switch out the Mamba from Mambular with (requires: from torch import nn)
self.mamba = nn.ModuleList()
for _ in range(self.hparams.get("n_layers", config.n_layers)):
    self.mamba.append(
        Mamba(
            d_model=self.hparams.get("d_model", config.d_model),
            # mamba_ssm calls this argument expand rather than expand_factor
            expand=self.hparams.get("expand_factor", config.expand_factor),
            d_conv=self.hparams.get("d_conv", config.d_conv),
        )
    )
See: https://github.com/state-spaces/mamba for further details on the original implementation.
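One detail to keep in mind with this swap: self.mamba is now an nn.ModuleList rather than a single module, so the forward pass has to iterate over the blocks, for example:

# in forward(), replace x = self.mamba(x) with a loop over the stacked blocks
for block in self.mamba:
    x = block(x)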
Thank you for your advice! I have tried the original Mamba implementation as suggested, and the training speed has significantly improved. It is now similar in speed to FT-Transformer and ResNet :).
v1.0.0 now includes both Mamba1 and Mamba2 from the mamba-ssm package. I will close this issue, but feel free to reopen it if there are further issues :)