PyTorch-VAE
Recommended way to load the model after training?
For example, I guess this is one way?
import torch
import yaml
from experiment import VAEXperiment

with open('configs/vae.yaml') as f:
    config = yaml.safe_load(f)
ckpt = torch.load('logs/VanillaVAE/version_1/checkpoints/last.ckpt')
# `model` is the VAE instance, which must be constructed before this point
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_state_dict(ckpt['state_dict'])
The trained model can then be accessed via experiment.model.
It took me a while to figure this out. Maybe add such instructions to the README?
Hello! What model did you pass to the experiment?
Never mind, it works with:
import torch
import yaml
from experiment import VAEXperiment
from models import *  # provides the vae_models registry

with open('./configs/bbvae.yaml') as f:
    config = yaml.safe_load(f)
# Build the model class named in the config, using the config's parameters
model = vae_models[config['model_params']['name']](**config['model_params'])
ckpt = torch.load('./logs/BetaVAE/version_0/checkpoints/last.ckpt')
# Wrap the model in the experiment and restore the trained weights
experiment = VAEXperiment(model, config['exp_params'])
experiment.load_state_dict(ckpt['state_dict'])
where I used the BetaVAE model.
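A possible alternative, if only the bare network is needed and not the VAEXperiment wrapper: since the wrapper exposes the network as experiment.model, the keys in ckpt['state_dict'] are presumably prefixed with "model.", so they would need that prefix removed before being loaded into the model directly. The sketch below is only an illustration under that assumption. The helper strip_prefix is hypothetical (not part of the repo), and the layer names and values in the example dict are made-up stand-ins for real tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="model."):
    """Return a copy holding only the keys that start with `prefix`,
    with that prefix removed. Keys without the prefix are dropped."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            stripped[key[len(prefix):]] = value
    return stripped


# Stand-in for ckpt['state_dict']; the key names are illustrative only
# and the integer values take the place of real weight tensors.
example_state = OrderedDict([
    ("model.encoder.0.weight", 1),
    ("model.fc_mu.bias", 2),
])

bare_state = strip_prefix(example_state)
print(list(bare_state))
```

With a real checkpoint, the result would be passed to the model itself, e.g. model.load_state_dict(strip_prefix(ckpt['state_dict'])), followed by model.eval() before running inference. It is worth printing a few keys of ckpt['state_dict'] first to confirm the prefix actually matches.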