PyTorch-VAE

Recommended way to load the model after training?

Status: Open · peacej opened this issue 3 years ago · 2 comments

For example, I guess this is one way?

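import torch
import yaml
from experiment import VAEXperiment
with open('configs/vae.yaml') as f:
    config = yaml.safe_load(f)
# map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only machine
ckpt = torch.load('logs/VanillaVAE/version_1/checkpoints/last.ckpt', map_location='cpu')
# `model` must already be the VAE built from config['model_params']
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_state_dict(ckpt['state_dict'])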

Then one can access the trained model via experiment.model.
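If only the network is needed, one possible alternative is to skip the VAEXperiment wrapper. Because the experiment stores the VAE as experiment.model, every key in the Lightning checkpoint's state_dict should carry a "model." prefix; stripping that prefix gives a dict the bare model can load. This is a sketch under that assumption: strip_prefix is a hypothetical helper, not part of PyTorch-VAE, and the demo uses plain strings with made-up key names in place of tensors so it runs without PyTorch.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Mapping[str, V], prefix: str = "model.") -> "OrderedDict[str, V]":
    """Keep only keys that start with `prefix` and drop the prefix from them.

    Works on any mapping, including the OrderedDict of tensors that
    torch.load returns under ckpt['state_dict'].
    Assumption: the wrapper stored the network under the attribute `model`.
    """
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


if __name__ == "__main__":
    # Hypothetical key names; strings stand in for tensors.
    lightning_style = OrderedDict([
        ("model.encoder.weight", "w0"),
        ("model.decoder.bias", "b0"),
    ])
    print(dict(strip_prefix(lightning_style)))
    # → {'encoder.weight': 'w0', 'decoder.bias': 'b0'}
```

Assuming model has been built from config['model_params'], model.load_state_dict(strip_prefix(ckpt['state_dict'])) should then load the weights. Its default strict=True raises an error if any key still does not match, which doubles as a quick check that the prefix assumption holds.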

It took me a while to figure this out. Maybe add such instructions to the README?

peacej · Aug 22 '22 12:08

Hello! What model did you pass to the experiment?

tudorjnu · Nov 24 '22 09:11

Never mind, it works with:

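import torch
import yaml
from experiment import VAEXperiment
from models import vae_models  # registry mapping model names to VAE classes

with open('./configs/bbvae.yaml') as f:
    config = yaml.safe_load(f)
# Build the model class named in the config with its hyperparameters
model = vae_models[config['model_params']['name']](**config['model_params'])
ckpt = torch.load('./logs/BetaVAE/version_0/checkpoints/last.ckpt', map_location='cpu')
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_state_dict(ckpt['state_dict'])
experiment.model.eval()  # inference mode (affects layers such as BatchNorm)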

Here I used the BetaVAE model (configs/bbvae.yaml).

tudorjnu · Nov 24 '22 10:11