pytorch-lightning
Support a Variable Number of Batches
Description & Motivation
I have a custom batch sampler whose number of batches is not known in advance, because it packs variable-sized samples to make the most of GPU memory. I believe vanilla PyTorch supports this behavior. In Lightning, however, the number of batches is computed ahead of time, as shown here: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/loops/fit_loop.py#L254
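For illustration, here is a minimal sketch of the kind of sampler described above, assuming a simple token-budget packing scheme (the class name `TokenBudgetBatchSampler`, the `sizes`/`max_tokens` parameters, and the per-epoch seeding are hypothetical, not the reporter's actual code). Because batches are formed from a shuffled order, the batch count can change between epochs, so the class deliberately defines no `__len__`.

```python
import random
from typing import Iterator, List, Sequence


class TokenBudgetBatchSampler:
    """Hypothetical sampler: packs sample indices into batches under a size budget.

    The grouping depends on the shuffled order, so the number of batches
    can differ from one epoch to the next. No __len__ is defined on purpose.
    """

    def __init__(self, sizes: Sequence[int], max_tokens: int, seed: int = 0) -> None:
        self.sizes = list(sizes)      # per-sample size (e.g. token count)
        self.max_tokens = max_tokens  # budget per batch
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Reseed the shuffle for each epoch (same idea as DistributedSampler.set_epoch).
        self.epoch = epoch

    def __iter__(self) -> Iterator[List[int]]:
        order = list(range(len(self.sizes)))
        random.Random(self.seed + self.epoch).shuffle(order)
        batch: List[int] = []
        total = 0
        for idx in order:
            size = self.sizes[idx]
            # Close the current batch if adding this sample would exceed the budget.
            # A single sample larger than the budget still gets its own batch.
            if batch and total + size > self.max_tokens:
                yield batch
                batch, total = [], 0
            batch.append(idx)
            total += size
        if batch:
            yield batch
```

In plain PyTorch such an object can be passed as `DataLoader(dataset, batch_sampler=...)`, since `batch_sampler` accepts any iterable that yields lists of indices.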
Pitch
Support a variable number of batches by recalculating the batch limit at the beginning of each epoch or at the end of the previous epoch.
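The proposal can be sketched as a framework-free training loop that re-queries the loader's length at the start of every epoch instead of caching it once. This is not Lightning's actual `fit_loop` code; `batches_for_epoch`, `run_epochs`, and the optional `set_epoch` hook on the loader are hypothetical names chosen for illustration.

```python
import math
from typing import Iterable, List


def batches_for_epoch(loader: Iterable, limit: float = math.inf) -> float:
    """Hypothetical helper: compute the batch limit for the upcoming epoch."""
    try:
        n = len(loader)  # re-evaluated every epoch, never cached
    except TypeError:
        # No __len__: the count is unknown until the iterator is exhausted.
        n = math.inf
    return min(n, limit)


def run_epochs(loader: Iterable, num_epochs: int) -> List[int]:
    """Run a toy loop and return how many batches each epoch consumed."""
    seen = []
    for epoch in range(num_epochs):
        # Hypothetical hook: let the loader regroup its batches for this epoch.
        if hasattr(loader, "set_epoch"):
            loader.set_epoch(epoch)
        limit = batches_for_epoch(loader)
        count = 0
        for _batch in loader:
            if count >= limit:
                break
            count += 1  # a real loop would run the training step here
        seen.append(count)
    return seen
```

Because the limit is recomputed after `set_epoch`, a loader whose length shrinks or grows between epochs is consumed fully each time, and a loader without `__len__` simply runs until exhaustion.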
Alternatives
No response
Additional context
No response
cc @borda