
Loading weights trained in torch to pytorch tabular

Open JonathanBhimani-Burrows opened this issue 1 year ago • 3 comments

Is your feature request related to a problem? Please describe. I'm looking to load a pretrained torch model into one of the PyTorch Tabular models to test the effectiveness of pretrained checkpoints.

Describe the solution you'd like The solution might already exist - this could be a bug. I would like to be able to, similar to regular torch, do something such as the pattern in the attached screenshot (not preserved; a sketch follows below).
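For context (the screenshot is not preserved here), the plain-PyTorch pattern being referred to is presumably something like the following sketch, where the model class and the checkpoint.pt path are hypothetical:

import torch
import torch.nn as nn

# Hypothetical stand-in for whatever architecture was pretrained.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))

# Regular PyTorch: read the saved state dict and copy the weights into the model in place.
state_dict = torch.load("checkpoint.pt", map_location="cpu")  # hypothetical checkpoint path
model.load_state_dict(state_dict, strict=True)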

Describe alternatives you've considered I've tested both load_weights and load_model from https://pytorch-tabular.readthedocs.io/en/latest/tabular_model/#pytorch_tabular.TabularModel.load_weights. load_model seems to require additional config files, and load_weights errors with AttributeError: 'TabularModel' object has no attribute 'model' when I run the following code:

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
path = xxx

tabular_model = tabular_model.load_weights(path)
print(f" Weights loaded")


JonathanBhimani-Burrows avatar Jan 03 '24 23:01 JonathanBhimani-Burrows

When you initialize TabularModel, it doesn't really initialize the underlying PyTorch model. To concretely initialize the PyTorch model, we need the data as well, because PyTorch Tabular infers some information (like input shape, output shape, etc.) from the data and initializes the model on the fly in fit.

What you can do is use the Low-Level API to create the data module and the model, but not train. Once you have the model, you can load the weights onto that model before training; a sketch of this is shown below.
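A minimal sketch of that suggestion, assuming train/test DataFrames and the configs from above, a hypothetical state-dict checkpoint at pretrained.pt, and that the checkpoint keys line up with the PyTorch Tabular model's layers:

import torch

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)

# Low-Level API: build the datamodule and the concrete PyTorch model without training.
datamodule = tabular_model.prepare_dataloader(train=train, validation=test, seed=42)
model = tabular_model.prepare_model(datamodule)

# Load the pretrained weights onto the now-initialized model.
# "pretrained.pt" is a hypothetical path to a plain state-dict checkpoint.
state_dict = torch.load("pretrained.pt", map_location="cpu")
model.load_state_dict(state_dict, strict=False)  # strict=False if only part of the network matches

# Fine-tune from those weights using the rest of the Low-Level API.
tabular_model.train(model, datamodule)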

manujosephv avatar Jan 13 '24 00:01 manujosephv

Thanks! Turns out the architectures of the two were implemented slightly differently, so it wasn't possible to load the weights (not just different state-dict names, but entirely different layers). Thanks again for your help, much appreciated!

JonathanBhimani-Burrows avatar Jan 17 '24 14:01 JonathanBhimani-Burrows

Re-opening this. I trained a model with PyTorch Tabular (FT Transformer) and saved it:

file_path = os.path.join(out_path, 'best_model')
print(f"Saving Best Model at {file_path}")
tabular_model.save_model(file_path)

Attempting to reload the model (checking if weights are updated)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    # experiment_config=experiment_config
)

print(f"++++++++++++++++++++ Preparing Dataloader ++++++++++++++++++++")
datamodule = tabular_model.prepare_dataloader(train=train, validation=test, seed=42)
print(f"++++++++++++++++++++ Created Model ++++++++++++++++++++")
model = tabular_model.prepare_model(datamodule)
w1 = model._backbone.transformer_blocks.mha_block_0.mha.to_qkv.weight
w2 = model._backbone.transformer_blocks.mha_block_0.mha.to_out.weight


p = r'/content/drive/MyDrive/SKF/SKF/experiments/PytorchTabular/FT_Trans/results_Vsave_test/trialno_cat_bin_ftconfig19/best_model'
tabular_model.load_model(p, map_location='cpu', strict=True)
print(f"Model loaded ")

w1p = model._backbone.transformer_blocks.mha_block_0.mha.to_qkv.weight
w2p = model._backbone.transformer_blocks.mha_block_0.mha.to_out.weight

print(w1==w1p)
print(w2==w2p)

Output:

tensor([[True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        ...,
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True]])
tensor([[True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        ...,
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True],
        [True, True, True, ..., True, True, True]])

So it seems like the model weights aren't being loaded, and there's no associated error message.

JonathanBhimani-Burrows avatar Jan 22 '24 17:01 JonathanBhimani-Burrows

Hi! Are you using the latest version? I tested it here and it seems to be working. I used "model.save_model" to save and "TabularModel.load_model" to load the model (screenshot not preserved).

ProgramadorArtificial avatar Jan 27 '24 22:01 ProgramadorArtificial

Did you re-initialize the model in between these tests? Otherwise, the best model would already be loaded and you wouldn't see this issue. I had to save, re-initialize, and reload, and that's where I saw that the weights hadn't been updated.

JonathanBhimani-Burrows avatar Jan 31 '24 14:01 JonathanBhimani-Burrows

Can you check this tutorial? This is the way you load a saved model, and in that case the model weights do get loaded and evaluation gives the same results.

Your workflow is slightly different, I guess.

  1. Train and save the model
  2. Initialize the Tabular model with same config, prepare the datamodule, and model
  3. Then load the saved model.

A few things that are wrong with it:

  1. Step 1 is completely unnecessary if you are loading a model that was saved using save_model.
  2. You are initializing the model and storing it in model, so that model isn't part of the TabularModel instance. If you check tabular_model.has_model(), it will be False.
  3. When you do load the model in step 3, PyTorch Tabular loads the model and returns it, but you aren't capturing the returned model.

Ideal workflow with save_model

# Trained a model with PyTorch Tabular (FT Transformer) and saved

file_path = os.path.join(out_path, 'best_model')
print(f"Saving Best Model at {file_path}")
tabular_model.save_model(file_path)

# Loading the model
loaded_model = TabularModel.load_model(file_path, map_location='cpu', strict=True)
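As an additional sanity check (a sketch, assuming a held-out test DataFrame; the prediction column name is illustrative and depends on your target and task), comparing predictions is a more robust test than comparing weight tensors on a re-initialized model:

# Predictions from the model that was just trained and saved.
orig_preds = tabular_model.predict(test)

# Predictions from the reloaded copy.
loaded_preds = loaded_model.predict(test)

# If the weights were restored, the prediction columns should match exactly.
# "target_prediction" is a placeholder column name; adjust it to your target.
print((orig_preds["target_prediction"] == loaded_preds["target_prediction"]).all())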

manujosephv avatar Feb 01 '24 00:02 manujosephv

Steps 1 and 2 were for replication information. Regarding your workflow, that makes sense, thank you for clarifying! I had in mind the regular PyTorch approach (not Lightning), where the model isn't returned but loaded in place. Thanks!

JonathanBhimani-Burrows avatar Feb 01 '24 20:02 JonathanBhimani-Burrows