
Reloading new data after one epoch doesn't work.


🐛 Bug

To Reproduce

https://colab.research.google.com/drive/1V0LDGDjW_Ettv0q6fQXFtnd-j2aIURDm?usp=sharing

Just follow the boring model in the notebook.

Expected behavior

tl;dr: The trainer should use or reload the data from the npDataModule class, but it never does; the breakpoints I set there are never reached.

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 11.3
  • Packages:
    • lightning: None
    • lightning_app: None
    • numpy: 1.21.6
    • pyTorch_debug: False
    • pyTorch_version: 1.12.0+cu113
    • pytorch-lightning: 1.7.0
    • tqdm: 4.64.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.13
    • version: #1 SMP Sun Apr 24 10:03:06 PDT 2022

Additional context

I would like to carry out adversarial defence training. In the first epoch I load a data module, and from the second epoch onwards I want to inject modified data into the training process at the beginning of each epoch.

I have tried three different approaches, and none of them works.

1.) The reload_dataloaders_every_n_epochs argument of the Trainer class. According to the documentation, I should implement train_dataloader(), val_dataloader(), etc. in my model (i.e. a LightningModule), and the trainer reloads the new dataloaders from them (see the sketch after this list). This does not work.

2.) Calling the functions

trainer.reset_train_dataloader()
trainer.reset_val_dataloader()

Here it is not clear where the data is loaded from (I find the documentation very sparse). It turns out that a new train_dataloader of type CombinedLoader is loaded, which is not what I wanted (val_dataloader remains None).

3.) Manually setting datamodule, train_dataloader, etc. on the trainer object. This does not work either.
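
For reference, a minimal, self-contained sketch of the pattern I understood from the documentation for approach 1.); RandomDataset and Model below are placeholders, not my actual code:

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    # placeholder dataset that just returns random tensors
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # with reload_dataloaders_every_n_epochs=1, the trainer should call this
        # hook again before every training epoch and pick up any new data
        return DataLoader(RandomDataset(32, 64), batch_size=2)


trainer = pl.Trainer(max_epochs=3, reload_dataloaders_every_n_epochs=1)
trainer.fit(Model())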

I am using the current version, 1.7. I don't know whether I'm missing something (I can't find anything in the documentation), but my workflow is simple from a data perspective (just new data each epoch), and I can't find anyone reporting similar problems. I'm not sure whether this is a bug or whether I'm missing implicit assumptions. I would appreciate any help, as I have been stuck on this problem for many work days and am quite desperate. Thank you very much!

cc @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj

KralaBenjamin avatar Aug 08 '22 16:08 KralaBenjamin

Hi @KralaBenjamin

You are passing dataloaders directly to fit() but then want to exchange them later via a datamodule. That is neither supported nor recommended. Here is a sketch of one possible way to do this:


import os

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

# BoringModel is the model from the reproduction notebook linked above


class YourDatamodule(pl.LightningDataModule):
    def __init__(self, ...):
        ...

        self._train_data = ...
    
    def train_dataloader(self):  
        ...
        return DataLoader(self._train_data, ...)


class AdversarialDefence(pl.Callback):
    ...

    def on_train_epoch_end(self, trainer, model):
        if trainer.current_epoch % self.attack_after_n_epochs == 0:
            new_data = ...
            trainer.datamodule._train_data = new_data



def run():

    model = BoringModel()
    datamodule = YourDatamodule()

    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        num_sanity_val_steps=0,
        max_epochs=5,
        enable_model_summary=False,
        reload_dataloaders_every_n_epochs=1,
        callbacks=[AdversarialDefence(0.5)]
    )
    trainer.fit(model, datamodule)  # <--- here pass in the datamodule

You can also define some helper methods in your DataModule to shift the responsibility of updating the data from the callback to the datamodule. That might help.
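For example, continuing the sketch above (swap_train_data is just a hypothetical helper name, not a Lightning API):

class YourDatamodule(pl.LightningDataModule):
    ...

    def swap_train_data(self, new_data):
        # the datamodule owns the update; callers don't touch private attributes
        self._train_data = new_data


class AdversarialDefence(pl.Callback):
    ...

    def on_train_epoch_end(self, trainer, model):
        if trainer.current_epoch % self.attack_after_n_epochs == 0:
            new_data = ...  # generate your adversarial samples here
            trainer.datamodule.swap_train_data(new_data)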

In the above solution, the Trainer can rely on calling the train_dataloader() method on the datamodule, which handles taking your data and configuring the DataLoader object. Then you can just focus on updating your dataset object from either the callback or from within the datamodule.

I haven't looked closely into your code, but I think it would also be possible to create your adversarial data samples online, replacing the existing data on the fly with the new samples.
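
One way to sketch that idea is a dataset wrapper along these lines (attack_fn is a placeholder for whatever attack you use; this wrapper is not part of Lightning):

from torch.utils.data import Dataset


class OnTheFlyAdversarialDataset(Dataset):
    """Wraps a clean dataset and perturbs each sample as it is fetched."""

    def __init__(self, clean_dataset, attack_fn):
        self.clean_dataset = clean_dataset
        self.attack_fn = attack_fn  # e.g. an FGSM-style step against the current model

    def __len__(self):
        return len(self.clean_dataset)

    def __getitem__(self, index):
        sample = self.clean_dataset[index]
        # replace the clean sample with an adversarially perturbed one on the fly
        return self.attack_fn(sample)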

awaelchli avatar Aug 09 '22 12:08 awaelchli

Thank you very much, it works now.

KralaBenjamin avatar Aug 11 '22 14:08 KralaBenjamin