`LightningModule.on_train_batch_end` executes after `Callback.on_train_batch_end`
Outline & Motivation
Hi! In my LightningModule, I have an on_train_batch_end hook which performs some post-processing of the results from the training_step (for example, logits normalisation and metric updates). I also have callbacks with an on_train_batch_end method, which expect the model outputs to be normalised.
However, I found out that the execution order is "LightningModule.training_step -> Callback.on_train_batch_end -> LightningModule.on_train_batch_end", while I was expecting it to be "LightningModule.training_step -> LightningModule.on_train_batch_end -> Callback.on_train_batch_end", which seems more natural to me.
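The order reported above can be reproduced with a toy dispatcher in plain Python. This is only a sketch and not Lightning's implementation: `ToyModule`, `ToyCallback`, and `run_batch` are hypothetical names, and the only behaviour mirrored is that callback hooks fire before the module hook of the same name.

```python
# Toy model of the hook order described above; not Lightning's actual code.
calls = []  # records the order in which hooks fire


class ToyCallback:
    def on_train_batch_end(self, outputs):
        calls.append("Callback.on_train_batch_end")


class ToyModule:
    def training_step(self, batch):
        calls.append("Module.training_step")
        return {"logits": batch}

    def on_train_batch_end(self, outputs):
        calls.append("Module.on_train_batch_end")


def run_batch(module, callbacks, batch):
    """Assumed dispatch order: step, then callbacks, then the module hook."""
    outputs = module.training_step(batch)
    for callback in callbacks:
        callback.on_train_batch_end(outputs)
    module.on_train_batch_end(outputs)
    return outputs


run_batch(ToyModule(), [ToyCallback()], batch=[0.1, 0.9])
print(" -> ".join(calls))
```

Any post-processing done in the module hook therefore happens too late for a callback that reads `outputs` in its own on_train_batch_end.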
I don't want to have this normalisation process in the training_step, because I want training_step to be an abstractmethod that child classes override, while on_train_batch_end is supported out of the box.
Concrete example:
Binary segmentation model: LightningModule.training_step produces the model logits; in LightningModule.on_train_batch_end I define how these are denormalized and pass them to the metrics; and I have a callback, SegmentationWriter, which expects the predictions to be denormalized so that it can visualise them and save them to disk. Note that I want to describe and execute the denormalisation process only once, so that the metrics and the segmentation writer use the same one (and I don't want to repeat myself in the segmentation callback).
Question: is this intended, or is there a specific reason why the callback's on_train_batch_end runs before the LightningModule's? To me the other way around feels more natural. Can I alter this order?
Thank you!
cc @justusschock @awaelchli @carmocca @borda
Hi @ioangatop Thanks for the question.
> Question: is it intended or a specific reason which the on_train_batch_end callback method come first from the LightningModule, as for me it feels more natural the otherway around? Can I alter this order?
I don't remember if there ever was a good reason, but at some point we had to standardize it so that the order is consistent, and so we ended up with callback first, then module. This was the case from very early on. For certain callbacks, it would indeed be advantageous to run after the LM hook, as can be seen here: https://github.com/Lightning-AI/lightning/blob/e24620c1af707e299d43c30732ce62b2fcb59df8/src/lightning/pytorch/loops/fit_loop.py#L361-L367
Arguing for a change in order is probably hard, considering that it would be a severe breaking change. It has also been suggested to make the order configurable, possibly globally. @carmocca also has a good explanation here on this topic: https://github.com/Lightning-AI/lightning/issues/17131#issue-1629709032
For your use case, could you use a hook that runs after on_train_batch_end? Then it would be guaranteed to run after LM.on_train_batch_end.
Hi @awaelchli thank you for your quick response!
My issue is that if I do this (that was my original implementation):
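```python
from typing import Callable

# Assumed import: the original snippet only shows the `pl` alias.
import lightning.pytorch as pl


class AbstractModelModule(pl.LightningModule):
    def __init__(self, on_step_end_processes: Callable):
        super().__init__()
        self.on_step_end_processes = on_step_end_processes

    def on_validation_step(self, batch, batch_idx):
        # Implemented by child classes: run the model and return its outputs.
        pass

    def validation_step(self, batch, batch_idx):
        outputs = self.on_validation_step(batch, batch_idx)
        self._on_validation_step_end(outputs, batch, batch_idx)
        return outputs

    def _on_validation_step_end(self, outputs, batch, batch_idx) -> None:
        # Shared post-processing (normalisation, metric updates, etc.).
        self.on_step_end_processes(outputs)


class ModelModuleOne(AbstractModelModule):
    ...

    def on_validation_step(self, batch, batch_idx):
        ...
```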
Here, ModelModuleOne will always try to initialise a val_dataloader, even if I don't want one (for example in some self-supervised setting), so I have to provide one: Lightning sees that validation_step is present, which is true because of the parent, even though I don't necessarily want to use it. If validation_step were left unimplemented in the parent class, it would be the responsibility of the child to define one, and if it does, on_validation_batch_end automatically kicks in and does the post-processing (metric logging, normalisation, etc.). One solution might be to have two parent classes, a train-only AbstractModelModule and an AbstractModelModuleWithValidation(AbstractModelModule) that adds validation, but that gets a bit out of hand IMO. Lastly, with this structure I don't get the nice out-of-the-box support for validation_step (error messages, standard Lightning syntax, etc.). Maybe you have another solution here that I haven't thought about?
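The inheritance problem described above can be illustrated with a toy override check in plain Python. This is a sketch under assumptions: `Base` stands in for the framework base class, `overrides` is a hypothetical helper, and Lightning's real detection logic may differ in detail.

```python
# Toy override check; Lightning's real detection logic may differ in detail.
class Base:
    def validation_step(self, batch):
        return None


class AbstractModule(Base):
    # Defining the step in the parent makes every child "have" one.
    def validation_step(self, batch):
        return self.on_validation_step(batch)

    def on_validation_step(self, batch):
        raise NotImplementedError


class SelfSupervised(AbstractModule):
    pass  # no validation wanted, yet validation_step is inherited


class PlainChild(Base):
    pass  # inherits only the framework default


def overrides(instance, name, base=Base):
    """True if the instance's class replaces `base`'s implementation of `name`."""
    return getattr(type(instance), name) is not getattr(base, name)


print(overrides(SelfSupervised(), "validation_step"))
print(overrides(PlainChild(), "validation_step"))
```

Under this kind of check, a child of the abstract parent looks as if it implements validation even when it never defines a step of its own, which matches the behaviour reported here.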
One solution (and my favourite) could be to support on_{train,validation,test,predict}_step_end natively in the LM, which would run before on_XXX_batch_end, regardless of whether that hook belongs to a callback or to the LM. So the order would be LM.on_{train,validation,test,predict}_step_end -> Callback.on_{train,validation,test,predict}_batch_end -> LM.on_{train,validation,test,predict}_batch_end. What do you think, would it complicate the API a lot?
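A toy sketch of this proposed order, in plain Python, shows why it would solve the use case. Everything here is hypothetical: `on_train_step_end` is the hook being proposed in this thread rather than an existing one, and `ToyModule`, `ToyWriter`, and `run_batch` are illustrative names only.

```python
# Toy sketch of the proposed order; `on_train_step_end` is hypothetical here.
import math


class ToyModule:
    def training_step(self, batch):
        return {"logits": list(batch)}

    def on_train_step_end(self, outputs):
        # Post-process once, so every later consumer sees the same values.
        outputs["probs"] = [1 / (1 + math.exp(-x)) for x in outputs["logits"]]

    def on_train_batch_end(self, outputs):
        pass


class ToyWriter:
    """Stands in for a callback that needs post-processed predictions."""

    def __init__(self):
        self.seen = None

    def on_train_batch_end(self, outputs):
        self.seen = outputs.get("probs")


def run_batch(module, callbacks, batch):
    outputs = module.training_step(batch)
    module.on_train_step_end(outputs)  # proposed: runs before any batch-end hook
    for callback in callbacks:
        callback.on_train_batch_end(outputs)
    module.on_train_batch_end(outputs)
    return outputs


writer = ToyWriter()
run_batch(ToyModule(), [writer], batch=[0.0, 2.0])
print(writer.seen)
```

Because the post-processing runs before any batch-end hook, the callback receives the processed values without repeating the logic itself.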
Another one could be to use the above structure and configure somewhere that, if the fit command is used and validation_step is present but val_dataloader is None, validation is skipped without raising an error. I don't know what the repercussions of this would be, to be honest, and operations like --trainer.fast_dev_run then do not work.
@awaelchli FYI, this is the code block I'm interested in: https://github.com/Lightning-AI/lightning/blob/fbdbe632c67b05158804b52f4345944781ca4f07/src/lightning/pytorch/loops/evaluation_loop.py#L408-L410
@ioangatop It looks like your inheritance structure violates the Liskov substitution principle. Instead of bending the hook system to work around its limitations, here are two simpler alternatives:
- set limit_val_batches=0 to disable validation for the models where you don't want to run it
- "erase" the inherited validation step by monkey patching it either on the class or the instance:
```python
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.model_helpers import is_overridden


class MyModel(BoringModel):
    ...


model = MyModel()
# configure_optimizers is inherited
print(is_overridden("configure_optimizers", model))
# override the instance
model.configure_optimizers = None
print(is_overridden("configure_optimizers", model))


class MyModel(BoringModel):
    # override the class
    configure_optimizers = None


model = MyModel()
print(is_overridden("configure_optimizers", model))
```