pytorch-widedeep
Problems running transformer models
Hello,
I'm trying to classify events for a dark matter direct detection experiment; the events are tabulated in a set of optimal features (all continuous). When I run both XGBoost and LightGBM I get AUCs of about 0.98. When I run an MLP model (without optimisation) I get about 0.93, which is somewhat behind the decision trees, but maybe this is the best one can get with an MLP. The issue comes with the transformer models: with those I get what looks like a random classifier (AUC ~0.5), so there must be something wrong in my script, but I can't identify the issue. Could you please have a look at my script and tell me if you see something wrong? This is my script:
import numpy as np
import torch
import pandas as pd

from pytorch_widedeep.initializers import XavierNormal
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import (
    SAINT,
    Wide,
    WideDeep,
    TabPerceiver,
    FTTransformer,
    TabFastFormer,
    TabTransformer,
)
from pytorch_widedeep.metrics import Accuracy, Precision
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.callbacks import (
    LRHistory,
    EarlyStopping,
    ModelCheckpoint,
)
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.initializers import XavierNormal, KaimingNormal

#from torchmetrics import AUROC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, roc_curve

use_cuda = torch.cuda.is_available()

if __name__ == "__main__":
    csv_file_path = '/lustre/ific.uv.es/ml/ific005/projects/direct_detection/data_dd/events2024/processed/combined_data.csv'

    # Load the CSV file into a DataFrame
    df = pd.read_csv(csv_file_path)
    print(df.head())

    continuous_cols = [
        'pA_S1', 'pH_S1', 'pHT_S1', 'pL_S1', 'pL90_S1', 'pRMSW_S1', 'pHTL_S1',
        'pA_S2', 'pH_S2', 'pHT_S2', 'pL_S2', 'pL90_S2', 'pRMSW_S2', 'pHTL_S2',
        'pbot', 'ptop', 'pdiffT',
    ]
    target = "Label"

    # Stratified 80/10/10 split into train, validation and test sets
    df_train, df_valid = train_test_split(
        df, test_size=0.2, stratify=df[target], random_state=1
    )
    df_valid, df_test = train_test_split(
        df_valid, test_size=0.5, stratify=df_valid[target], random_state=1
    )

    # Fit the preprocessor (which scales the continuous columns) on the training set only
    tab_preprocessor = TabPreprocessor(
        continuous_cols=continuous_cols,
        scale=True,
        with_attention=True,
    )
    X_tab_train = tab_preprocessor.fit_transform(df_train)
    X_tab_valid = tab_preprocessor.transform(df_valid)
    X_tab_test = tab_preprocessor.transform(df_test)

    # target
    y_train = df_train[target].values
    y_valid = df_valid[target].values
    y_test = df_test[target].values

    wide = Wide(input_dim=np.unique(X_tab_train).shape[0], pred_dim=1)

    tab_transformer = TabTransformer(
        column_idx=tab_preprocessor.column_idx,
        continuous_cols=continuous_cols,
        embed_continuous=True,
        n_blocks=4
    )

    saint = SAINT(
        column_idx=tab_preprocessor.column_idx,
        continuous_cols=continuous_cols,
        cont_norm_layer="batchnorm",
        n_blocks=4,
    )

    tab_perceiver = TabPerceiver(
        column_idx=tab_preprocessor.column_idx,
        continuous_cols=continuous_cols,
        n_latents=6,
        latent_dim=16,
        n_latent_blocks=4,
        n_perceiver_blocks=2,
        share_weights=False,
    )

    tab_fastformer = TabFastFormer(
        column_idx=tab_preprocessor.column_idx,
        continuous_cols=continuous_cols,
        n_blocks=4,
        n_heads=4,
        share_qv_weights=False,
        share_weights=False,
    )

    ft_transformer = FTTransformer(
        column_idx=tab_preprocessor.column_idx,
        continuous_cols=continuous_cols,
        input_dim=32,
        kv_compression_factor=0.5,
        n_blocks=3,
        n_heads=4,
    )

    # Train and evaluate each transformer-based model in turn
    for tab_model in [
        tab_transformer,
        saint,
        ft_transformer,
        tab_perceiver,
        tab_fastformer,
    ]:
        model = WideDeep(deeptabular=tab_model, pred_dim=1)

        wide_opt = torch.optim.Adam(model.parameters(), lr=0.01)
        deep_opt = torch.optim.Adam(model.parameters(), lr=0.01)
        wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
        deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)

        optimizers = {"wide": wide_opt, "deeptabular": deep_opt}
        schedulers = {"wide": wide_sch, "deeptabular": deep_sch}
        initializers = {"wide": KaimingNormal, "deeptabular": XavierNormal}
        callbacks = [
            LRHistory(n_epochs=10),
            EarlyStopping(patience=5),
            ModelCheckpoint(filepath="model_weights/wd_out"),
        ]
        metrics = [Accuracy]

        trainer = Trainer(
            model,
            objective="binary",
            optimizers=optimizers,
            lr_schedulers=schedulers,
            initializers=initializers,
            callbacks=callbacks,
            metrics=metrics,
        )

        trainer.fit(
            X_train={"X_tab": X_tab_train, "target": y_train},
            X_val={"X_tab": X_tab_valid, "target": y_valid},
            n_epochs=10,
            batch_size=100,
        )

        df_pred = trainer.predict(X_tab=X_tab_test)
        print(classification_report(df_test[target].to_list(), df_pred))
        #print("Actual predicted values:\n{}".format(np.unique(df_pred, return_counts=True)))
        auc = roc_auc_score(df_test[target], df_pred)
        print('AUC', auc)
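One detail worth checking in the script above: `roc_auc_score` is fed the output of `trainer.predict`, which (assuming a recent pytorch-widedeep version) returns hard 0/1 labels for `objective="binary"` rather than scores. ROC AUC measures how well the scores rank positive examples above negative ones, so thresholded labels discard most of that ranking, and a model that predicts a single class everywhere scores exactly 0.5. The sketch below computes AUC from scratch with the pairwise (Mann-Whitney) definition, counting ties as one half, to make the effect visible; `pairwise_auc` and the toy data are made up for illustration. With pytorch-widedeep, the positive-class probability would typically come from `trainer.predict_proba(X_tab=X_tab_test)[:, 1]`, but check this against the version you have installed.

```python
from itertools import product


def pairwise_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    A pair where the positive example scores higher counts 1, a tie counts 0.5.
    O(n_pos * n_neg), so only meant for small illustrative inputs.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    total = 0.0
    for p, n in product(pos, neg):
        if p > n:
            total += 1.0
        elif p == n:
            total += 0.5
    return total / (len(pos) * len(neg))


# Hypothetical toy data: 3 positives followed by 3 negatives.
y_true = [1, 1, 1, 0, 0, 0]
probs = [0.9, 0.6, 0.4, 0.45, 0.3, 0.1]       # continuous scores
hard = [1 if p > 0.5 else 0 for p in probs]  # thresholded labels, like predict()
constant = [0] * len(y_true)                 # a model that always predicts class 0

print(pairwise_auc(y_true, probs))     # 8 of 9 pairs ranked correctly, ~0.889
print(pairwise_auc(y_true, hard))      # ties introduced by thresholding, ~0.833
print(pairwise_auc(y_true, constant))  # every pair tied, exactly 0.5
```

In practice `roc_auc_score` from scikit-learn is the right tool; the point is only that it should receive probabilities or logits, not the output of a 0.5 threshold.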
Thanks a lot!
Roberto
Hey @rruizdeaustri
I will have a more detailed look, but in general here are some comments:
- Use a simpler model: forget about the wide component and simply use a deeptabular component with its defaults. (Review the code in your example, since the optimizers and schedulers are not correctly defined. The Trainer not throwing an error is intentional, though I might change it.) Just define your Trainer as:
trainer = Trainer(
    model,
    objective="binary",
    callbacks=[ModelCheckpoint(filepath="model_weights/wd_out")],
    metrics=[Accuracy],
)
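The point about the optimizers can be illustrated without PyTorch. In the original script both `wide_opt` and `deep_opt` are created from `model.parameters()`, so each one wraps every parameter of the model, and `WideDeep(deeptabular=tab_model)` has no wide component for the `"wide"` key to refer to. Assuming the Trainer steps every optimizer in the dict after each backward pass, every parameter is then updated twice per batch with the same gradient. The sketch below reproduces that with a toy gradient-descent optimizer on f(w) = w² instead of Adam; `PlainSGD` and `run` are hypothetical names used only for this illustration. In PyTorch terms, the usual alternative is to give each optimizer only its own component's parameters (for example `model.deeptabular.parameters()`, assuming that attribute name), or simply to drop the custom optimizers as suggested above.

```python
class PlainSGD:
    """Toy optimizer over a shared parameter cell: value <- value - lr * grad."""

    def __init__(self, param, lr):
        self.param = param  # dict with "value" and "grad", shared by reference
        self.lr = lr

    def step(self):
        self.param["value"] -= self.lr * self.param["grad"]


def run(optimizers, param, n_batches):
    """Per batch: compute the gradient of f(w) = w**2 once, then step every optimizer."""
    for _ in range(n_batches):
        param["grad"] = 2.0 * param["value"]
        for opt in optimizers:
            opt.step()
    return param["value"]


# One optimizer owning the parameter: w = 1 - 0.1 * 2, roughly 0.8.
w_single = {"value": 1.0, "grad": 0.0}
print(run([PlainSGD(w_single, lr=0.1)], w_single, n_batches=1))

# Two optimizers built over the *same* parameter (like wide_opt and deep_opt
# both wrapping model.parameters()): the update is applied twice per batch,
# giving roughly 0.6, the same as a single optimizer with lr=0.2.
w_shared = {"value": 1.0, "grad": 0.0}
print(run([PlainSGD(w_shared, lr=0.1), PlainSGD(w_shared, lr=0.1)], w_shared, n_batches=1))
```

Because each optimizer keeps its own state and sees the same gradients, the same reasoning suggests that two Adam optimizers over one parameter set behave like a single Adam optimizer with twice the step size.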
- The results with Transformer-based models depend A LOT on the parameters, far more than with GBMs, where XGBoost, LightGBM and CatBoost all perform close to their best out of the box. You could have a look at this relatively old post to see if it helps.
I hope this helps. Let me know how you get on with this, and I'll see if I can help more.
Hi @jrzaurin,
I have made the modifications you suggested and the results make more sense now. I'm optimising hyperparameters with Optuna for the ResNet and transformer models, but the results are still far from those obtained with LightGBM: AUC ~0.93 versus ~0.98 for LightGBM.
Thanks!
Rbt
Hey @rruizdeaustri
Thanks for sharing the results :)
A gap of 0.05 is perhaps a bit too much; maybe I could look at some examples if you are willing to share them. However, I am afraid that this is the "brutal" reality for most (true) real-world cases when it comes to DL vs GBMs.
You could also try some other libraries to see if their implementations work better or give you better results.
I have used DL for tabular data on a few occasions, but never aimed to beat GBMs, since I knew it was a lost battle.
Hi @jrzaurin,
Yes, the difference is too large!
I could share the files I'm using for training, as well as the data, if you like. Let me know!
Thanks!
Hey @rruizdeaustri!
I am traveling at the moment, but if you join the Slack channel we can move the conversation there and share the files. I'll see if I have time to give it a go myself! :)
Thanks!