
Cifar100 wrn404 regularization

Open. Opened by f-dangel (0 comments).

Problem: the cifar100_wrn404 forward pass fails when regularization is added.

This is caused by a missing return statement in the method that creates the regularization groups. The method therefore returns None instead of the groups.
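The failure mode can be reproduced without DeepOBS. The sketch below is a hypothetical stand-in, not DeepOBS code: the class names, the `regularization_groups` and `regularization_loss` methods, and the `0.5 * factor * sum(w**2)` L2 convention are all assumptions made for illustration. It only shows why a missing `return` surfaces as the `AttributeError` quoted in the demo.

```python
class BuggyProblem:
    """Hypothetical test problem whose group builder forgets to return."""

    def __init__(self, weights, l2_factor):
        self.weights = weights      # hypothetical: list of parameter lists
        self.l2_factor = l2_factor  # hypothetical: one L2 factor for all params

    def regularization_groups(self):
        # Map each L2 factor to the parameters it applies to.
        groups = {self.l2_factor: self.weights}
        # Bug: no `return groups`, so Python implicitly returns None.

    def regularization_loss(self):
        total = 0.0
        # Calling .items() on None raises AttributeError here.
        for factor, params in self.regularization_groups().items():
            for p in params:
                total += factor * sum(w * w for w in p)
        return 0.5 * total  # assumed convention: 0.5 * factor * ||w||^2


class FixedProblem(BuggyProblem):
    """Same problem with the missing return statement added."""

    def regularization_groups(self):
        groups = {self.l2_factor: self.weights}
        return groups  # the one-line fix


if __name__ == "__main__":
    weights = [[1.0, 2.0], [3.0]]
    try:
        BuggyProblem(weights, 0.1).regularization_loss()
    except AttributeError as err:
        print(err)  # 'NoneType' object has no attribute 'items'
    print(FixedProblem(weights, 0.1).regularization_loss())
```

The unregularized path never calls the group builder, which is consistent with only the regularized forward pass crashing.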

Demo:

  • Before: the forward pass without regularization succeeds, but the forward pass with regularization raises AttributeError: 'NoneType' object has no attribute 'items'
  • After: both forward passes work

from deepobs.pytorch.testproblems import cifar100_wrn404

# Wide ResNet 40-4 on CIFAR-100 with batch size 5
tp = cifar100_wrn404(5)
tp.set_up()
tp.train_init_op()

# works: loss on one training batch without the regularization term
loss, _ = tp.get_batch_loss_and_accuracy(add_regularization_if_available=False)
print(loss)

# crashes before the fix (AttributeError), works after it
regularized_loss, _ = tp.get_batch_loss_and_accuracy(
    add_regularization_if_available=True
)
print(regularized_loss)

Minor:

  • Moved the call to the test problem's set_up method into the __init__ of the test.

f-dangel, Aug 03 '20 16:08