pytorch-image-models
Fix MultiEpochsDataLoader when not all is consumed
If MultiEpochsDataLoader.__iter__ gets called while a previous call's items weren't all consumed, it continues the previous iteration instead of starting a new one, which is incorrect. I propose this admittedly inefficient fix, as I can't think of another one; it's still better than incorrect behavior.
An example of when this happens is when you just want to overfit the first batch of the dataset, e.g., by using PyTorch Lightning's Trainer overfit_batches argument.
I added a warning. Note it should be logged once per call. Also, note this consumption happens in the main thread.
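To make the idea concrete, here is a minimal, torch-free sketch of the proposed fix. The class name, the `_consumed` counter, and the use of `itertools.cycle` as a stand-in for the persistent iterator are all hypothetical simplifications, not timm's actual code; the point is only the drain-and-warn behavior in `__iter__`:

```python
import itertools
import warnings


class MultiEpochsLoader:
    """Simplified sketch (not timm's real class): one persistent iterator
    cycles over the data forever; __iter__ hands out one epoch's worth of
    items from it per call."""

    def __init__(self, data):
        self._data = list(data)
        # Persistent iterator, never recreated between epochs.
        self._iterator = itertools.cycle(self._data)
        # Total items handed out so far; _consumed % len(self) tells us
        # how far into the current epoch the stream is.
        self._consumed = 0

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        # Proposed fix: if the previous epoch wasn't fully consumed,
        # drain the leftover items (in the main thread) so this epoch
        # starts at the beginning, and warn once per call about the cost.
        if self._consumed % len(self) != 0:
            warnings.warn(
                "Previous iteration was not fully consumed; "
                "draining the remainder before starting a new epoch."
            )
            while self._consumed % len(self) != 0:
                next(self._iterator)
                self._consumed += 1
        for _ in range(len(self)):
            self._consumed += 1
            yield next(self._iterator)
```

With this, abandoning an epoch after one item and then re-iterating still yields the data from the start, at the price of silently walking the persistent iterator to the epoch boundary first.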
@bryant1410 given it's not the cleanest workaround, and it could actually have a real performance impact on large datasets (iterating through the whole dataset to reach the end), what are the other scenarios where this is needed? For overfitting one batch, is there any point to using this loader in the first place?
When overfitting a batch there's little point beyond it being faster. It's just a convenience to leave the loader in place and then orthogonally toggle a flag to overfit a batch or not.
But the main issue, IMO, is that it's incorrect: it fails silently and you don't know it, because it's an iterable that behaves like an iterator but without a __next__ method.
What about just raising a warning, without consuming the data (which would happen in the main thread)?
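For comparison, the warning-only alternative suggested here could look like the sketch below (again a hypothetical, torch-free simplification, not timm's code): the partially-consumed state is detected and reported, but nothing is drained, so the new epoch knowingly resumes mid-stream:

```python
import itertools
import warnings


class MultiEpochsLoaderWarnOnly:
    """Hypothetical warning-only variant: detect a partially consumed
    previous epoch and warn, but do not drain the leftover items; the
    stream simply continues from where it left off."""

    def __init__(self, data):
        self._data = list(data)
        self._iterator = itertools.cycle(self._data)  # persistent iterator
        self._consumed = 0  # items handed out in the current epoch

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        if self._consumed % len(self) != 0:
            warnings.warn(
                "Previous iteration was not fully consumed; "
                "this epoch will resume mid-stream."
            )
            self._consumed = 0  # restart the per-epoch count from here
        for _ in range(len(self)):
            self._consumed += 1
            yield next(self._iterator)
```

This keeps re-iteration cheap (no main-thread drain over a large dataset), at the cost of the epoch boundary being wrong; the warning just makes that visible instead of silent.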