Enable Hyperparameter logging from any hook in the LightningModule
🚀 Feature
Make it possible to call save_hyperparameters from any hook in the LightningModule.
Motivation
Sometimes the dataset has hyperparameters that should be logged. However, the LightningDataModule is only accessible from the LightningModule once the Trainer has been initialized. Thus, it would be useful to be able to call save_hyperparameters from on_fit_start, where the Trainer is available and the hyperparameters of the dataset can easily be collected, e.g. through self.trainer.datamodule.build_hparams().
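For illustration, the requested pattern would look roughly like this (a sketch; build_hparams is my own helper on the data module, not a Lightning API):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def on_fit_start(self) -> None:
        # The Trainer (and its datamodule) is attached at this point, so
        # dataset-level hyperparameters could be collected and saved here.
        dataset_hparams = self.trainer.datamodule.build_hparams()  # hypothetical user helper
        self.save_hyperparameters(dataset_hparams)  # currently raises when called outside __init__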
Pitch
save_hyperparameters shouldn't look for init args in the local variables when called outside the __init__ method.
Currently, this behaviour causes an exception at line 154 of utilities/parsing.py:
local_args = {k: local_vars[k] for k in init_parameters.keys()}
because the function is looking for the init parameters in the local variables, which are only available when called from __init__.
Suggestion: skip collecting init parameters when save_hyperparameters is called from places other than __init__.
Alternatives
Save init parameters and add them later.
Additional context
N/A
cc @borda @carmocca @justusschock @awaelchli @ananthsub @ninginthecloud @jjenniferdai @rohitgr7
Hello there!
Sometimes the dataset has hyperparameters that should be logged. However, the LightningDataModule is only accessible from the LightningModule once the Trainer has been initialized. Thus, it would be useful to call save_hyperparameters from on_fit_start
For this particular use case, I recommend calling self.save_hyperparameters() in the DataModule directly rather than via the LightningModule hooks. This is supported in the latest PL versions. The recorded parameters (merged with the ones from the LightningModule) get logged to the logger.
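A minimal sketch of that pattern (assuming a recent PL version; the argument names are just placeholders):

import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        # Recorded on the datamodule, merged with the LightningModule's hparams,
        # and sent to the logger when training starts.
        self.save_hyperparameters()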
Apart from that, I think this is an honest feature request. However, let me remind everyone that the main objective of save_hyperparameters is NOT JUST to send the parameters to the logger. That is a nice-to-have, but the MAIN purpose of this method is to save the parameters to the checkpoint so that they can be used correctly when the model is loaded back from the checkpoint (via LightningModule.load_from_checkpoint). From that perspective, it would be very error-prone to let this method be called from every hook. It is imperative that save_hyperparameters captures exactly the arguments passed to __init__, not more, not less, and not any modified ones. For this reason, I recommend not going forward with this feature. Instead, we could figure out better error handling. I'm curious what others think about this.
Finally, when you only care about logging some parameters, this is also possible by accessing self.logger in any hook (or by using self.log_dict).
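For example (a sketch; the parameter values are made up):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def on_fit_start(self) -> None:
        # Anything you want recorded in the experiment manager.
        extra_params = {"dataset_size": 50_000, "augmentation": "rand_augment"}
        if self.logger is not None:
            self.logger.log_hyperparams(extra_params)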
@awaelchli Good point. Maybe it would be good to immediately raise an exception when save_hyperparameters is not called from __init__?
Yes, that sounds reasonable. Not trivial but doable. Anyone here up to give this a try?
I think it is quite a quick fix. I'll give it a go sometime this week.
@awaelchli what is the right kind of exception to raise for this? (https://github.com/PyTorchLightning/pytorch-lightning/commit/20ecd7626d9f8555b5d6216eedca3a569f563637)
I'd suggest RuntimeError
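For illustration, the guard inside save_hyperparameters could look something like this (just a sketch of the idea, not the actual implementation):

import inspect

def _ensure_called_from_init() -> None:
    # stack()[0] is this helper, [1] is save_hyperparameters, [2] is the caller's frame.
    caller = inspect.stack()[2]
    if caller.function != "__init__":
        raise RuntimeError(
            "save_hyperparameters() can only be called from the __init__ method, "
            f"but was called from `{caller.function}`."
        )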
Updating the assignee as this is already being worked on in #13240.
I had a similar need as the OP, e.g. saving LightningDataModule hyperparameters with save_hyperparameters() at a specific point of the production pipeline using a LightningModule hook. While the API made it look like this worked as intended (no errors were raised) in my initial implementation, the requested hyperparameters didn't appear in the experiment manager after the run had finished.
From my perspective, the documentation or the API should be more explicit about how/where save_hyperparameters() should be used. Currently, the method's documentation in LightningDataModule links to the corresponding part of the LightningModule docs. However, the sentence stating that the method should be used "within your LightningModule's init method" is not strictly true in this case, as calling the method also seems to work from LightningDataModule.setup(). As mentioned above, I only found out later that calling the method from a LightningModule hook (on_train_start() specifically) didn't work, so I think the API should have warned me about the improper use.
The same story code-wise: I have the following method implemented in my LightningDataModule.
class LitDataModule(L.LightningDataModule):
    ...

    def log_attributes(self) -> None:
        logged_attributes = {
            "att_1": value_1,
            "att_2": value_2,
            ...
        }
        self.save_hyperparameters(logged_attributes)
... and I was calling the method from a LightningModule hook like so:
def on_train_start(self) -> None:
    if self.trainer.datamodule:
        self.trainer.datamodule.log_attributes()
In this case, the requested attributes didn't appear in the experiment manager, for which I don't have a clear explanation. I changed the logic afterwards to:
class LitDataModule(L.LightningDataModule):
    def setup(self, stage: str = None) -> None:
        ...
        if stage == "fit":  # `stage` is a plain string here, so compare it directly
            self.log_attributes()
... which seems to work at least for now, unless I run into something unexpected.
The purpose of save_hyperparameters() is to save the constructor arguments of the model class to the checkpoint, so that we can construct a model when we load the checkpoint. Usually you don't need to save the constructor arguments of the data module to the checkpoint.
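To make that round trip concrete (a minimal sketch; the class and argument names are made up):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, hidden_dim: int = 128, lr: float = 1e-3):
        super().__init__()
        # Captures exactly the __init__ arguments and writes them into the checkpoint.
        self.save_hyperparameters()

# Later, the stored arguments are used to rebuild the model:
# model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt")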
The purpose of logging the hyperparameters to MLflow or Neptune is that you know exactly how the training was done, you can repeat it, and when you compare two training runs, you know exactly what was different. For this reason we would like to save not only the model and data arguments, but also the Trainer arguments.
Why do we log just the model's constructor arguments, instead of the whole configuration? At least with Lightning CLI, it seems to be as simple as:
class SaveConfigCallback:
    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        logger = trainer.logger
        if isinstance(logger, Logger):
            logger.log_hyperparams(self.config)