disent
[FEATURE]: Model Saving and Checkpointing
Is your feature request related to a problem? Please describe.
Model saving and checkpointing is currently disabled for experiment/run.py. This was due to old pickling errors and the extensive use of wandb for logging; actual saved models were not needed at the time.
Describe the solution you'd like
Re-enable model checkpointing, and allow training to be continued.
- Schedules will need to be pickled
- Add config options to enable saving and continuing training
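A rough sketch of what this could look like on the PyTorch Lightning side (the callback settings and resume path below are illustrative, not disent's actual config keys):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Periodic saving via a ModelCheckpoint callback; resuming is then a matter of
# pointing fit() back at the saved .ckpt file (PyTorch Lightning >= 1.5 API).
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",   # illustrative output directory
    every_n_epochs=1,         # save frequency
    save_last=True,           # keep a 'last.ckpt' to resume from
)

trainer = pl.Trainer(max_epochs=100, callbacks=[checkpoint_cb])
# trainer.fit(framework, datamodule)                                     # fresh run
# trainer.fit(framework, datamodule, ckpt_path="checkpoints/last.ckpt")  # resumed run
```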
Hi 👋 First of all thanks for the fantastic framework!
I started to work on this issue since I need model saving for my visualization project. Model saving seems to be very simple; however, loading a checkpoint appears to be a bit more complicated as far as I understand. I found that a saved checkpoint can be loaded, but the model parameter for e.g. BetaVae is an omegaconf.dictconfig.DictConfig instead of an actual AutoEncoder object if run with hydra. I think this is because in run.py the framework is created using hydra, and maybe some magic is happening there. It works fine if I save and load the models with the standard Python API like in your example. Will investigate...
Ah, it seems like someone hit the same roadblock before: https://github.com/Lightning-AI/lightning/discussions/6144
Hi @meffmadd, really glad you are finding it useful!
I managed to get away with just the wandb results a while back, so I never got around to fixing this.
I'll investigate based on the information you provided and get back to you. Thank you for that!
You noted the object is an OmegaConf instance. I don't think it would break too much if we switch that over to a dictionary and recursively convert all the values; there is a built-in function for this.
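Something along these lines, assuming OmegaConf.to_container is the built-in in question:

```python
from omegaconf import OmegaConf

# Recursively convert a (possibly nested) DictConfig into plain dicts/lists,
# which pickle and checkpoint like ordinary Python objects.
cfg = OmegaConf.create({"model": {"z_size": 9, "beta": 4.0}})
plain = OmegaConf.to_container(cfg, resolve=True)
print(type(plain), plain)  # <class 'dict'> {'model': {'z_size': 9, 'beta': 4.0}}
```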
(As for pytorch lightning, I have become a bit disillusioned towards it, as it has placed certain constraints on the framework that were never intended.)
To get to my question, how important is API stability for you right now?
API stability is not a major concern for me. I only planned on using the hydra configs, but maybe using the Python API is more useful for me (since it removes a layer of complexity). Yeah, frameworks are nice as long as their magic works and there is documentation 😅
For model saving when running with hydra, I think I found a simple workaround in the Vae class:
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    # replace the DictConfig that hydra injected into 'hyper_parameters' with
    # the live torch module, so loading restores a real AutoEncoder again
    checkpoint["hyper_parameters"]["model"] = self._model
    return super().on_save_checkpoint(checkpoint)
This manually sets the model in the checkpoint and also works for loading! I can implement simple model saving and make a pull request with this if you like so no API changes are necessary.
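For completeness, loading then looks roughly like this (a sketch: the import path follows the package's examples, the checkpoint path is hypothetical, and I'm assuming the usual Lightning load_from_checkpoint entry point):

```python
from disent.frameworks.vae import BetaVae

# With the real model stored under 'hyper_parameters' by the hook above,
# Lightning can reconstruct the framework straight from the .ckpt file.
framework = BetaVae.load_from_checkpoint("checkpoints/last.ckpt")  # hypothetical path
framework.eval()
```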
And a quick unrelated question:
Could you tell me a config for beta-VAE that works well with dSprites? I use bce loss, but this somehow creates NaN values in the encoder output. If I use the norm_conv64 framework it works well, but this does not seem to be standard as per your warning.
Created pull request #37
Thank you so much for the PR! I left a few comments about tests. We just need to make sure to add the new keys to the configs and (possibly) update the tests to test the checkpointing.
As for your question: that should not be happening with the BCE loss. It may be due to largely unrelated things, like the strength of the regularization term or the learning rate.
- I know dSprites is a binary dataset, but I would still recommend using MSE loss, as the BCE loss has led to prevalent errors in the VAE field. For example, when moving from dsprites to dsprites-imagenet or even shapes3d, BCE loss is no longer applicable (and then results also don't transfer as well, needing different hparams).
EDIT: on this note, another reason for the BCE loss failing could be the dataset normalization. I am not sure if that has a part to play, as the output is also normalized. There may be a logic/precision error there.
- You can try disabling this in the dataset config
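To illustrate the normalization point, a toy check (not disent code; the 0.5/0.5 normalization constants are illustrative): BCE assumes both the reconstruction and the target lie in [0, 1], and mean/std normalization pushes the targets outside that range.

```python
import torch
import torch.nn.functional as F

# BCE assumes reconstructions and targets both lie in [0, 1]; mean/std
# normalization pushes the targets outside that range.
x_recon = torch.sigmoid(torch.randn(4, 1, 64, 64))  # decoder output in (0, 1)
x_raw   = torch.rand(4, 1, 64, 64)                   # raw pixel targets in [0, 1]
x_norm  = (x_raw - 0.5) / 0.5                        # "normalized" targets in [-1, 1]

print(F.binary_cross_entropy(x_recon, x_raw).item())  # well-defined
print(float(x_norm.min()), float(x_norm.max()))       # outside [0, 1] -> BCE no longer well-behaved
```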
I will fix the configs now!
Thanks for your answer! I will try it with MSE but with a higher learning rate because when I tested the beta-VAE with MSE it did not converge at all.
I think possibly a lower beta value then too.
Closing this with your changes from:
- #37
Thank you for contributing!
Now released under v0.7.0