
Transformations not applied on Train DataLoader

Open evertonaleixo opened this issue 2 years ago • 3 comments

🐛 Describe the bug When you add a transformation to a dataset and a DataLoader is created in the training loop, the transformation is not applied to the items yielded by that DataLoader.

🐜 To Reproduce

from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from avalanche.models import SimpleMLP
from avalanche.logging import TextLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.strategies import BaseStrategy
from avalanche.training.plugins import StrategyPlugin

from avalanche.benchmarks.classic import RotatedMNIST

from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from torchvision.transforms import Normalize, ToTensor

class SimplePlugin(StrategyPlugin):
    def __init__(self, debug: bool = False):
        super().__init__()
        self.debug = debug

    def after_train_dataset_adaptation(self, strategy: BaseStrategy, **kwargs):
        def x_transform(x):
            print('call x_transform', type(x), x)
            x = ToTensor()(x)
            return Normalize((0.1307,), (0.3081,))(x)
        
        def y_transform(y):
            print('call y_transform')
            return y

        strategy.adapted_dataset = strategy.adapted_dataset.replace_transforms(
            x_transform, y_transform)


model = SimpleMLP()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = CrossEntropyLoss()
benchmark = RotatedMNIST(
    n_experiences=2,
    seed=1234
)

evaluator = EvaluationPlugin(
    accuracy_metrics(epoch=True, experience=True, stream=True),
    loss_metrics(epoch=True, experience=True, stream=True),
    benchmark=benchmark,
    strict_checks=False,
    loggers=[TextLogger(open('log.txt', 'a'))]
)

plugin = SimplePlugin(debug=False)

strategy = BaseStrategy(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_mb_size=64,
    device='cpu',
    eval_mb_size=64,
    train_epochs=1,
    eval_every=1,
    plugins=[plugin],
    evaluator=evaluator,
)

strategy.train(benchmark.train_stream)
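
As a quick check (an editor's sketch, not part of the original report, assuming AvalancheDataset items unpack as (x, y, task_label) and that replace_transforms accepts None for the target transform), the same transform can be applied directly to an experience dataset and a single item fetched by hand; the print fires on direct access, while it did not during strategy.train(...):

from torchvision.transforms import Normalize, ToTensor

def debug_x_transform(x):
    # Same transform as in the plugin, renamed for this standalone check.
    print('call debug_x_transform', type(x))
    return Normalize((0.1307,), (0.3081,))(ToTensor()(x))

# Fetch one item directly from the experience dataset, bypassing the
# TaskBalancedDataLoader used inside the training loop.
debug_ds = benchmark.train_stream[0].dataset.replace_transforms(
    debug_x_transform, None)
x, y, t = debug_ds[0]  # the print appears here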

🐝 Expected behavior The transformations should be applied to the data when it is loaded from the DataLoader in the training loop, so the prints from x_transform and y_transform should appear many times. They do not, because TaskBalancedDataLoader uses task_set rather than adapted_dataset (data), and the two appear to hold different transformations.

From file: data_loader.py

task_datasets = []
for task_label in self.data.task_set:
    tdata = self.data.task_set[task_label]
    task_datasets.append(tdata)

A workaround is to propagate the transformations to each per-task dataset taken from task_set:

for task_label in self.data.task_set:
    tdata = self.data.task_set[task_label]
    tdata = tdata.replace_transforms(
        self.data.transform,
        self.data.target_transform,
    )
    task_datasets.append(tdata)
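
With this patch applied in the same loop of data_loader.py, each per-task dataset should inherit the transformations set on the adapted dataset, and the x_transform / y_transform prints should appear for every loaded minibatch.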

evertonaleixo avatar Apr 25 '22 23:04 evertonaleixo

TaskBalancedDataLoader uses task_set rather than adapted_dataset (data), and the two appear to hold different transformations.

This is the problem. Ideally, task_set should keep the AvalancheDataset's transformations. @lrzpellegrini is there any reason behind this choice, or is it just a bug?

AntonioCarta avatar Apr 26 '22 07:04 AntonioCarta

It works fine with the latest commit on the master branch. I think the issue is related to an older version of Avalanche.

This is the code adapted to the latest version (note that some class names have changed):

from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.core import SupervisedPlugin
from avalanche.models import SimpleMLP
from avalanche.logging import TextLogger
from avalanche.training.plugins import EvaluationPlugin

from avalanche.benchmarks.classic import RotatedMNIST

from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from torchvision.transforms import Normalize, ToTensor

from avalanche.training.templates import SupervisedTemplate


class SimplePlugin(SupervisedPlugin):
    def __init__(self, debug: bool = False):
        super().__init__()
        self.debug = debug

    def after_train_dataset_adaptation(self, strategy, **kwargs):
        def x_transform(x):
            print('call x_transform', type(x), x)
            x = ToTensor()(x)
            return Normalize((0.1307,), (0.3081,))(x)

        def y_transform(y):
            print('call y_transform')
            return y

        strategy.adapted_dataset = strategy.adapted_dataset.replace_transforms(
            x_transform, y_transform)


def main():
    model = SimpleMLP()
    optimizer = SGD(model.parameters(), lr=1e-3)
    criterion = CrossEntropyLoss()
    benchmark = RotatedMNIST(
        n_experiences=2,
        seed=1234
    )

    evaluator = EvaluationPlugin(
        accuracy_metrics(epoch=True, experience=True, stream=True),
        loss_metrics(epoch=True, experience=True, stream=True),
        benchmark=benchmark,
        strict_checks=False,
        loggers=[TextLogger(open('log.txt', 'a'))]
    )

    plugin = SimplePlugin(debug=False)

    strategy = SupervisedTemplate(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_mb_size=64,
        device='cpu',
        eval_mb_size=64,
        train_epochs=1,
        eval_every=1,
        plugins=[plugin],
        evaluator=evaluator,
    )

    strategy.train(benchmark.train_stream)


if __name__ == '__main__':
    main()

@AntonioCarta we should consider adding aliases for the legacy names so that they map to the new classes. Something like BaseStrategy -> SupervisedTemplate, StrategyPlugin -> SupervisedPlugin.
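
A minimal sketch of what such an alias could look like (an editor's illustration, not code that exists in Avalanche), raising a DeprecationWarning that points users to the new name:

import warnings

from avalanche.training.templates import SupervisedTemplate

class BaseStrategy(SupervisedTemplate):
    # Hypothetical backward-compatibility alias; not part of Avalanche.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "BaseStrategy was renamed to SupervisedTemplate; "
            "please update your imports.",
            DeprecationWarning,
        )
        super().__init__(*args, **kwargs)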

lrzpellegrini avatar Apr 26 '22 14:04 lrzpellegrini

@AntonioCarta we should consider adding aliases for the legacy names so that they map to the new classes. Something like BaseStrategy -> SupervisedTemplate, StrategyPlugin -> SupervisedPlugin.

The problem is that these aliases may be very misleading, because the actual base class is currently very different from the supervised template. At least an explicit error is easy to understand.

AntonioCarta avatar Apr 26 '22 15:04 AntonioCarta