`trainer.fit_loop.setup_data()` does not refresh train dataset in `LightningModule`
Bug description
PR #16726 replaces the reset_*_dataloader() method calls with the respective Loop.setup_data() calls. This is also mentioned in the migration guide.
However, on versions <= 1.9, calling reset_train_dataloader() would reinstantiate the dataloader from a LightningModule's train_dataloader() method. This behaviour is now gone.
My specific use case is that I need to update my model's training dataset during training. I use on_train_epoch_end() (or a similar hook) to call reset_train_dataloader() so that the next training epoch sees the updated dataset. I posted a minimal example below. You can run it on both v1.9 and v2.0 to see the exact difference: v1.9 runs without problems, whereas v2.0 fails the second assertion in training_step(). I tested it on fresh conda installs of both versions using Python 3.10.
In case I am using the wrong loop to call setup_data() or am using the new interface incorrectly, please let me know. In that case I would also recommend adding some more hints to the migration guide or to PR #16726, since the current advice is not exactly clear (e.g., which loops count as "top level"?).
What version are you seeing the problem on?
2.0+
How to reproduce the bug
try:
    import lightning
except ModuleNotFoundError:
    import pytorch_lightning as lightning

import torch
from torch.utils.data import DataLoader, TensorDataset


class Model(lightning.LightningModule):
    def __init__(self):
        super().__init__()
        self.train_data = TensorDataset(torch.zeros(1, 1))

    def configure_optimizers(self):
        return None

    def on_train_epoch_end(self):
        self.train_data = TensorDataset(torch.ones(1, 1))
        if int(lightning.__version__[0]) < 2:
            # for version < 2.0 (works)
            self.trainer.reset_train_dataloader()
        else:
            # for version >= 2.0 (does not work)
            self.trainer.fit_loop.setup_data()

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
        )

    def training_step(self, batch, batch_idx):
        # de-tuple
        batch = batch[0]
        if self.trainer.global_step == 0:
            assert torch.allclose(batch, torch.zeros_like(batch))
        else:
            # this assertion fails on lightning v2.0
            assert torch.allclose(batch, torch.ones_like(batch))
        return torch.tensor(0.0, requires_grad=True)


model = Model()
trainer = lightning.Trainer(max_steps=2)
trainer.fit(model)
Error messages and logs
File "/home/lars/code/python/lightning-trainable/playground.py", line 38, in training_step
assert torch.allclose(batch, torch.ones_like(batch))
AssertionError
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer, FitLoop, LightningModule
#- PyTorch Lightning Version (e.g., 1.5.0): 1.9 / 2.0
#- Lightning App Version (e.g., 0.5.2): -
#- PyTorch Version (e.g., 2.0): 1.9 / 2.0
#- Python version (e.g., 3.9): 3.10
#- OS (e.g., Linux): Ubuntu
#- CUDA/cuDNN version: 11.7
#- GPU models and configuration: RTX 2070
#- How you installed Lightning (`conda`, `pip`, source): conda
#- Running environment of LightningApp (e.g. local, cloud): local
More info
No response
cc @borda @justusschock @awaelchli @carmocca
You can do Trainer(reload_dataloaders_every_n_epochs=1) to accomplish this
This solution is unsatisfactory since I want to a) avoid reloading every epoch and b) be able to reload irregularly and on command.
I also think you should re-add the bug tag, since the functionality of setup_data still seems broken to me, even if there technically is a different way to do this.
This is "working as expected" given the current design of setup_data, which doesn't run if the data is already setup and the trainer flag is not configured, see this early exit: https://github.com/Lightning-AI/lightning/blob/b2717f68789638f34bd9baca2d74b62a06c16ca9/src/lightning/pytorch/loops/fit_loop.py#L210-L211
If you make that if statement not trigger, you'll see your code passing. For example by adding trainer.fit_loop._combined_loader = None before you call setup_data
The easiest way to change this would be to add a force: bool flag to setup_data so that you can skip that logic, making this a feature
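For example, the hook from the reproduction above could become something like this (a sketch of the suggested workaround only; _combined_loader is a private attribute and may change between releases):

def on_train_epoch_end(self):
    self.train_data = TensorDataset(torch.ones(1, 1))
    # clear the cached loader so setup_data() skips its early exit;
    # note: _combined_loader is private API and not guaranteed to stay stable
    self.trainer.fit_loop._combined_loader = None
    self.trainer.fit_loop.setup_data()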
An idea for this part:
"This solution is unsatisfactory since I want to a) avoid reloading every epoch and b) be able to reload irregularly and on command."
You could still set Trainer(reload_dataloaders_every_n_epochs=1) just so that the trainer calls the dataloader methods every epoch. Inside train_dataloader(), you can then decide whether you actually want to rebuild the dataloader or just return a cached one:
def train_dataloader(self):
    if condition:
        # recreate
        self.train_dl = DataLoader(
            self.train_data,
        )
    return self.train_dl
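Putting this together, a minimal runnable sketch of that pattern, assuming a hypothetical _rebuild_dataloader flag (the flag name and the bookkeeping around it are placeholders, not Lightning API):

import lightning
import torch
from torch.utils.data import DataLoader, TensorDataset


class Model(lightning.LightningModule):
    def __init__(self):
        super().__init__()
        self.train_data = TensorDataset(torch.zeros(1, 1))
        self.train_dl = None
        self._rebuild_dataloader = True  # hypothetical flag, forces the first build

    def configure_optimizers(self):
        return None

    def on_train_epoch_end(self):
        # update the dataset on your own schedule, then request a rebuild
        self.train_data = TensorDataset(torch.ones(1, 1))
        self._rebuild_dataloader = True

    def train_dataloader(self):
        # called every epoch because of reload_dataloaders_every_n_epochs=1,
        # but only rebuilds the DataLoader when the flag is set
        if self._rebuild_dataloader or self.train_dl is None:
            self.train_dl = DataLoader(self.train_data)
            self._rebuild_dataloader = False
        return self.train_dl

    def training_step(self, batch, batch_idx):
        return torch.tensor(0.0, requires_grad=True)


trainer = lightning.Trainer(max_steps=2, reload_dataloaders_every_n_epochs=1)
trainer.fit(Model())

This keeps reload_dataloaders_every_n_epochs=1 cheap in epochs where nothing changed, since the cached DataLoader is simply returned.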
👍 to this issue.
- The migration guide is pretty sparse on details. It just says to replace trainer.reset_*_dataloader() with Loop.setup_data(), which is vague. It took me some time to figure out how to actually invoke it; I had to set a breakpoint, dir() the Trainer, and guess-and-check to arrive at trying trainer.fit_loop.setup_data(). More documentation would definitely help.
- I agree with the original poster that the fact that this replacement does not, in fact, have the same functionality as the old one is unexpected behavior and worth documenting as well.
- The suggested solution by @carmocca of adding trainer.fit_loop._combined_loader = None seems to have done the trick for me. I also agree with the suggestion to make this a boolean flag passed to Loop.setup_data().
@Borda @carmocca
Adding on to this discussion: I also have a custom callback that was using the reset_*_dataloader() methods and that I'm migrating to the new Lightning API (right now I'm using 2.1.3).
class FeatureExtractorCallback(Callback):
    def __init__(self, devices, feature_extractor: nn.Module) -> None:
        super().__init__()
        self.devices = devices
        self.feature_extractor = feature_extractor

    @rank_zero_only
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when fit begins."""
        if not hasattr(trainer, "datamodule"):
            raise ValueError("Trainer must have a datamodule attribute.")
        log.info("Extracting features!")
        device = f"cuda:{self.devices[0]}" if (self.devices is not None and isinstance(self.devices, list)) else "cpu"
        self.feature_extractor.to(device)
        train_dataset = convert_feature_dataset(trainer.datamodule.train_dataloader(), self.feature_extractor, device)
        val_dataset = convert_feature_dataset(trainer.datamodule.val_dataloader(), self.feature_extractor, device)
        trainer.datamodule.train_dataset = train_dataset
        trainer.datamodule.val_dataset = val_dataset
        trainer._should_reload_train_dl = True
        trainer._should_reload_val_dl = True
        trainer.fit_loop.setup_data()
        trainer.validate_loop.setup_data()
        trainer._should_reload_train_dl = False
        trainer._should_reload_val_dl = False
        # trainer.reset_train_dataloader()
        # trainer.reset_val_dataloader()

    @rank_zero_only
    def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Called when test begins."""
        if not hasattr(trainer, "datamodule"):
            raise ValueError("Trainer must have a datamodule attribute.")
        log.info("Extracting features!")
        device = f"cuda:{self.devices[0]}" if (self.devices is not None and isinstance(self.devices, list)) else "cpu"
        self.feature_extractor.to(device)
        test_dataset = convert_feature_dataset(trainer.datamodule.test_dataloader(), self.feature_extractor, device)
        trainer.datamodule.test_dataset = test_dataset
        trainer._should_reload_val_dl = True
        trainer.test_loop.setup_data()
        trainer._should_reload_val_dl = False
        # trainer.reset_test_dataloader()
The on_fit_start part is working: the training and validation dataloaders are changed and reloaded, and I correctly receive batches coming from the newly created datasets.
But test is not working! Debugging, I can see that test_loop.setup_data() is called properly and the datamodule holds the new data (_combined_loader contains a reference to the correct dataset), but the batch I receive in test_step comes from the original test dataset, not from the updated one. Any idea on this? It seems like a bug to me, but I can't figure out what's causing it.