pytorch_tabular
Loading weights trained in torch to pytorch tabular
Is your feature request related to a problem? Please describe.
I'm looking to load a pretrained torch model into one of the PyTorch Tabular models to test the effectiveness of pretrained checkpoints.
Describe the solution you'd like
solution might exist - could be a bug. I would like to be able to, similar to regular torch, do something such as
Describe alternatives you've considered
I've tested both load_weights and load_model from the docs (https://pytorch-tabular.readthedocs.io/en/latest/tabular_model/#pytorch_tabular.TabularModel.load_weights). load_model seems to require additional config files, and load_weights fails with AttributeError: 'TabularModel' object has no attribute 'model' when I run the following code:
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
path = xxx
tabular_model = tabular_model.load_weights(path)
print(f" Weights loaded")
When you initialize TabularModel, it doesn't actually initialize the PyTorch model. To concretely initialize the PyTorch model, we need the data as well, because PyTorch Tabular infers some information (output shape, input shape, etc.) from the data and initializes the model on the fly in fit.
What you can do is use the Low Level API to create the datamodule and the model, but not train. Once you have the model, you can load the weights onto that model before training.
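A minimal sketch of that suggestion, assuming the two architectures really do match and that the checkpoint is a plain state dict written with torch.save. prepare_dataloader and prepare_model are the TabularModel methods that appear later in this thread; the helper name load_pretrained_weights and the checkpoint_path argument are made up for illustration.

```python
import torch


def load_pretrained_weights(tabular_model, train, validation, checkpoint_path):
    """Build the model via the low-level API, then load weights before training.

    Assumption: `checkpoint_path` (hypothetical) holds a state dict saved with
    torch.save whose keys and shapes match the PyTorch Tabular model.
    """
    # Build the datamodule first: input/output shapes are inferred from the data.
    datamodule = tabular_model.prepare_dataloader(
        train=train, validation=validation, seed=42
    )
    # Create the untrained PyTorch model from the inferred shapes.
    model = tabular_model.prepare_model(datamodule)
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # strict=True raises a RuntimeError on any missing or unexpected key.
    model.load_state_dict(state_dict, strict=True)
    return datamodule, model
```

The returned datamodule and model can then be handed to the low-level training step instead of calling fit, which would build a fresh model.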
Thanks! It turns out the two architectures were implemented slightly differently, so it wasn't possible to load the weights (not just different state dict names, but entirely different layers). Thanks again for your help, much appreciated!
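For anyone hitting the same wall, a quick way to see whether two models are compatible is to compare their state dicts before attempting a load. The sketch below uses two toy nn.Sequential networks as stand-ins for the pretrained model and the PyTorch Tabular model; the helper diff_state_dicts is hypothetical and not part of either library.

```python
import torch.nn as nn


def diff_state_dicts(source: nn.Module, target: nn.Module):
    """Report which parameters could not be copied from `source` into `target`."""
    src, tgt = source.state_dict(), target.state_dict()
    missing = sorted(set(tgt) - set(src))     # target expects these, source lacks them
    unexpected = sorted(set(src) - set(tgt))  # source has these, target does not
    shape_mismatch = sorted(
        k for k in set(src) & set(tgt) if src[k].shape != tgt[k].shape
    )
    return missing, unexpected, shape_mismatch


# Toy stand-ins: same idea, different layer sizes and depth.
pretrained = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
candidate = nn.Sequential(
    nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 4), nn.ReLU(), nn.Linear(4, 1)
)

missing, unexpected, shape_mismatch = diff_state_dicts(pretrained, candidate)
print("missing:", missing)
print("unexpected:", unexpected)
print("shape mismatch:", shape_mismatch)
```

If only the key names differ, renaming keys in the state dict may be enough; missing layers or shape mismatches mean the weights cannot be transferred directly.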
Re-opening this. I trained a model with PyTorch Tabular (FT Trans) and saved it:
file_path = os.path.join(out_path, 'best_model')
print(f"Saving Best Model at {file_path}")
tabular_model.save_model(file_path)
Attempting to reload the model (checking if weights are updated)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    # experiment_config=experiment_config
)
print(f"++++++++++++++++++++ Preparing Dataloader ++++++++++++++++++++")
datamodule = tabular_model.prepare_dataloader(
    train=train, validation=test, seed=42
)
print(f"++++++++++++++++++++ Created Model ++++++++++++++++++++")
model = tabular_model.prepare_model(
    datamodule
)
w1 = model._backbone.transformer_blocks.mha_block_0.mha.to_qkv.weight
w2 = model._backbone.transformer_blocks.mha_block_0.mha.to_out.weight
p = r'/content/drive/MyDrive/SKF/SKF/experiments/PytorchTabular/FT_Trans/results_Vsave_test/trialno_cat_bin_ftconfig19/best_model'
tabular_model.load_model(p, map_location='cpu', strict=True)
print(f"Model loaded ")
w1p = model._backbone.transformer_blocks.mha_block_0.mha.to_qkv.weight
w2p = model._backbone.transformer_blocks.mha_block_0.mha.to_out.weight
print(w1==w1p)
print(w2==w2p)
Output:
tensor([[True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        ...,
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True]])
tensor([[True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        ...,
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True]])
So it seems like the model weights aren't being loaded, and there's no error message associated
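A side note on the check itself: w1 and w1p are both references to the same parameter of model, so an element-wise comparison between them is True whether or not anything was loaded. A more reliable test takes a detached copy before loading and compares against that. The sketch below shows the difference with two toy nn.Linear layers standing in for the trained and re-initialized models; it does not use PyTorch Tabular at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
trained = nn.Linear(4, 2)  # stands in for the saved, trained model
fresh = nn.Linear(4, 2)    # stands in for the re-initialized model

# A bare reference aliases the parameter, so it always equals itself.
alias_before = fresh.weight
# A detached clone is an independent snapshot of the values before loading.
snapshot_before = fresh.weight.detach().clone()

# Plain PyTorch loads in place: the parameter object stays, its values change.
fresh.load_state_dict(trained.state_dict())

alias_check = torch.equal(alias_before, fresh.weight)        # True: same object
snapshot_check = torch.equal(snapshot_before, fresh.weight)  # False: values changed
print(alias_check, snapshot_check)
```

torch.equal returns a single bool, which is also easier to read than a printed tensor of per-element comparisons.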
Hi!
Are you using the latest version? I tested it here and it seems to be working. I used "model.save_model" to save and "TabularModel.load_model" to load the model.
Did you re-initialize the model in between these tests? Otherwise, the best model would already be loaded and you wouldn't see this issue. I had to save, re-initialize and reload, and that's where I saw that the weights hadn't been updated.
Can you check this tutorial? This is the way you load a saved model, and in that case the model weights do get loaded and evaluation gives the same results.
Your workflow is slightly different, I guess.
- Train and save the model
- Initialize the Tabular model with same config, prepare the datamodule, and model
- Then load the saved model.
A few things are wrong with it.
- Step 1 is completely unnecessary if you are loading a saved model which is saved using save_model.
- You are initializing the model and storing it in model. Now the model isn't part of TabularModel. If you check tabular_model.has_model(), it will be False.
- When you do load the model, in step 3, PyTorch Tabular loads the model and returns it. But you aren't catching the returned model.
Ideal workflow with save_model
#Trained a model with Pytorch tabular (FT Trans) and saved
file_path = os.path.join(out_path, 'best_model')
print(f"Saving Best Model at {file_path}")
tabular_model.save_model(file_path)
# Loading the model
loaded_model = TabularModel.load_model(p, map_location='cpu', strict=True)
Steps 1 and 2 were for replication information. Regarding your workflow, that makes sense, thank you for clarifying! I had in mind the regular PyTorch approach (not Lightning), where the model isn't returned but is loaded in place. Thanks!
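The two loading styles discussed above can be contrasted with a toy example. TinyWrapper below is a hypothetical stand-in written only for this illustration; it is not PyTorch Tabular's implementation. It only mimics the pattern in which load_model is a class-level loader that returns a new object, as opposed to nn.Module.load_state_dict, which updates an existing module in place.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyWrapper:
    """Hypothetical wrapper whose loader returns a new object (illustration only)."""

    def __init__(self, model: nn.Module):
        self.model = model

    def save_model(self, path: str) -> None:
        torch.save(self.model.state_dict(), path)

    @classmethod
    def load_model(cls, path: str) -> "TinyWrapper":
        model = nn.Linear(4, 2)
        model.load_state_dict(torch.load(path, map_location="cpu"))
        return cls(model)  # a new instance; the caller has to keep it


torch.manual_seed(0)
trained = TinyWrapper(nn.Linear(4, 2))
fresh = TinyWrapper(nn.Linear(4, 2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pt")
    trained.save_model(path)
    fresh.load_model(path)                 # return value discarded: `fresh` is unchanged
    loaded = TinyWrapper.load_model(path)  # return value kept

fresh_matches = torch.equal(fresh.model.weight, trained.model.weight)
loaded_matches = torch.equal(loaded.model.weight, trained.model.weight)
print(fresh_matches, loaded_matches)
```

Only the object bound to the return value carries the saved weights, which mirrors the advice to write loaded_model = TabularModel.load_model(...).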