pycox

How can I save and load a torchtuples model?

Status: Open · kyuchoi opened this issue 2 years ago · 5 comments

First of all, thank you for your great work and your replies!

Can I save and load a torchtuples model using model.save_model_weights() and model.load_model_weights()? It seems possible, but when I save and load the model, I get a NoneType error instead of a torchtuples model. Could you please give me a snippet? I need this because the model is too large and needs to be re-loaded in another .py file. Many thanks!

kyuchoi · Oct 25 '21 12:10

Thanks for the kind words!

See #29 and also this comment, but I guess I should type out an answer here in case anyone else asks.

Alternative 1

Simplest (but less robust, because the whole network object is pickled):

    from pycox.models import LogisticHazard

    model = LogisticHazard(net)  # net: your torch.nn.Module

    # save the whole network
    model.save_net('mynet.pt')

    # load the whole network into an existing model
    model.load_net('mynet.pt')

    # make a new model with the stored network
    new_model = LogisticHazard('mynet.pt')

For the semi-parametric Cox models, the baseline hazards are also pickled, so a stored Cox model might not be readable by newer versions of pickle/Python. Ideally this should be handled differently.
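Since save_net hands the whole network object to torch.save (see the source further down the thread), loading it back depends on pickle finding the network class again. A minimal plain-PyTorch sketch of the same round trip, assuming a hypothetical TinyNet and PyTorch 1.13 or later (where torch.load accepts weights_only). Since PyTorch 2.6, torch.load defaults to weights_only=True, which rejects pickled modules, so the sketch passes weights_only=False explicitly; only do that for files you trust.

```python
import os
import tempfile

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for your own network.

    Defined at module level so pickle can find it again by name.
    """

    def __init__(self, in_features: int = 4, out_features: int = 3):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def save_and_reload_whole_net(net: nn.Module, path: str) -> nn.Module:
    # Pickle the entire module object: a reference to its class plus weights.
    torch.save(net, path)
    # Full-module files need weights_only=False on PyTorch >= 2.6.
    return torch.load(path, weights_only=False)


torch.manual_seed(0)
net = TinyNet()
with tempfile.TemporaryDirectory() as tmp:
    restored = save_and_reload_whole_net(net, os.path.join(tmp, "mynet.pt"))

x = torch.randn(2, 4)
print(type(restored).__name__, torch.equal(net(x), restored(x)))
```

The loaded object is a full TinyNet again, but only because the class is importable under the same name in the loading script.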

Alternative 2

Store only the weights. This does not store the network itself, so the weights need to be loaded into a network with the same architecture.

    # save only the weights
    model.save_model_weights('myweights.pt')

    # load them into a model whose network has the same architecture
    model.load_model_weights('myweights.pt')
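In plain PyTorch, storing only the weights amounts to saving a state_dict and loading it into a freshly built network of the same architecture. A minimal sketch, where build_net is a hypothetical placeholder for your own architecture (PyTorch 1.13 or later for weights_only). The loading happens in place: if load_model_weights likewise returns nothing (an assumption here), then writing model = model.load_model_weights('myweights.pt') would leave model set to None, which may explain the NoneType error in the question. Call it without assigning the result.

```python
import os
import tempfile

import torch
from torch import nn


def build_net() -> nn.Module:
    # Hypothetical architecture; it must match the network the weights
    # were saved from, layer for layer.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))


def save_weights(net: nn.Module, path: str) -> None:
    # Only parameter/buffer tensors are stored, not the class itself.
    torch.save(net.state_dict(), path)


def load_weights(path: str) -> nn.Module:
    # Rebuild the architecture first (e.g. in another .py file) ...
    net = build_net()
    # ... then copy the stored tensors into it. load_state_dict works in
    # place, so keep using `net` rather than its return value.
    net.load_state_dict(torch.load(path, weights_only=True))
    return net


torch.manual_seed(0)
original = build_net()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "myweights.pt")
    save_weights(original, path)
    restored = load_weights(path)

x = torch.randn(2, 4)
print(torch.equal(original(x), restored(x)))
```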

havakv · Oct 30 '21 13:10

Hi Havard, when I try to save the model, I get this error. Do you know what the problem is?

    File ~/Documents/survival_analysis/survival_estimate.py:52
        model.save_net('mynet.pt')

    File ~/anaconda3/envs/myenv/lib/python3.9/site-packages/torchtuples/base.py:681, in save_net
        return torch.save(self.net, path, **kwargs)

    File ~/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/serialization.py:380, in save
        _save(obj, opened_zipfile, pickle_module, pickle_protocol)

    File ~/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/serialization.py:589, in _save
        pickler.dump(obj)

    AttributeError: Can't pickle local object 'train_survmodel.<locals>.Net'

mahootiha-maryam · Jun 15 '22 10:06

Hm, not sure. It seems to be trying to save some 'train_survmodel.<locals>.Net'. What is that? Which model is this?

havakv · Jun 16 '22 06:06

train_survmodel is a function in which I build my net, fit the model with LogisticHazard, and finally return the model. After that I try to save the model, and that is when I get the pickle error.

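    import torch
    import torchtuples as tt
    from pycox.models import LogisticHazard

    def train_survmodel(labtrans, dl_train, dl_val):
        # Encoder_FC is the poster's own network (not shown); the traceback
        # suggests its class (Net) is defined inside this function.
        net = Encoder_FC()
        model = LogisticHazard(net, tt.optim.Adam(0.01), duration_index=labtrans.cuts)
        model.set_device(torch.device("cuda:0"))
        epochs = 50
        log = model.fit_dataloader(dl_train, epochs, val_dataloader=dl_val)
        _ = log.plot()
        return model

    newmodel = train_survmodel(labtrans, dl_train, dl_val)
    newmodel.save_net('mynet.pt')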

Is there another way to save the model? My net is deep and my images are huge, so I cannot retrain it every time.

mahootiha-maryam · Jul 02 '22 16:07

Hmm, you might need to move your model to the CPU before you can save it. Alternatively, you could try saving the network parameters manually with torch functions. If you look at the source code, you can see that we're just calling a simple torch function:

    def save_net(self, path, **kwargs):
        """Save self.net to file (e.g. net.pt).
        Arguments:
            path {str} -- Path to file.
            **kwargs are passed to torch.save
        Returns:
            None
        """
        return torch.save(self.net, path, **kwargs)
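The error itself comes from pickle, which torch.save uses for whole objects: pickle records a class by its module and qualified name and must be able to look it up again. A class defined inside a function gets a qualified name like 'train_survmodel.<locals>.Net', which cannot be looked up, so pickling fails. A minimal pure-Python sketch (the class and function names are hypothetical mirrors of this thread; the exact exception type may differ between Python versions):

```python
import pickle


class ModuleLevelNet:
    """Defined at module level: pickle can find it by name."""


def train_survmodel():
    # Hypothetical mirror of the thread's setup: a class defined *inside*
    # a function gets the qualified name 'train_survmodel.<locals>.Net'.
    class Net:
        pass

    return Net()


def try_pickle(obj) -> str:
    try:
        pickle.dumps(obj)
        return "ok"
    except (AttributeError, pickle.PicklingError) as err:
        # Python 3.9 raises AttributeError here, as in the traceback above;
        # other versions may raise pickle.PicklingError instead.
        return f"failed: {err}"


print(try_pickle(ModuleLevelNet()))
print(try_pickle(train_survmodel()))
```

Moving the class to module level (or into a separate file that both scripts import) should let the whole network be pickled; storing only the weights avoids pickling the class altogether.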

An alternative, more robust approach is to store only the weights with save_model_weights, whose code just makes the call:

    def save_model_weights(self, path, **kwargs):
        """Save the model weights.
        Parameters:
            path: The filepath of the model.
            **kwargs: Arguments passed to torch.save method.
        """
        return torch.save(self.net.state_dict(), path, **kwargs)

This doesn't store the network structure, so to load the weights you first need to build the network and then call load_model_weights.
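Applied to the setup in this thread, a minimal plain-PyTorch sketch (build_net and its layer sizes are hypothetical): saving the whole module fails for a locally defined class, its state_dict saves fine, and reloading means rebuilding the same architecture first. It uses in-memory buffers instead of files and needs PyTorch 1.13 or later for weights_only.

```python
import io
import pickle

import torch
from torch import nn


def build_net() -> nn.Module:
    # Hypothetical builder mirroring the thread: the class is defined
    # inside a function, so its qualified name contains '<locals>'.
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 3)

        def forward(self, x):
            return self.fc(x)

    return Net()


torch.manual_seed(0)
trained = build_net()

# 1) Pickling the whole module fails for a locally defined class.
try:
    torch.save(trained, io.BytesIO())
    whole_net_saved = True
except (AttributeError, pickle.PicklingError):
    whole_net_saved = False

# 2) A state_dict holds only tensors, so it saves without the class.
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# 3) To reload (e.g. in another script), rebuild the same architecture
#    and copy the stored weights into it.
restored = build_net()
restored.load_state_dict(torch.load(buffer, weights_only=True))

x = torch.randn(2, 4)
print(whole_net_saved, torch.equal(trained(x), restored(x)))
```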

havakv · Jul 21 '22 07:07