DeepOBS
Cifar100 wrn404 regularization
Problem: The forward pass of `cifar100_wrn404` fails when regularization is added.
This is caused by a missing return statement in the method that creates the regularization groups.
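The failure mode can be illustrated with a small self-contained sketch. The function names, the `{l2_factor: [parameters]}` dictionary layout, the "regularize weights only" grouping rule and the `0.5` factor in the penalty are assumptions made for illustration; this is not DeepOBS's actual implementation. Parameters are plain Python lists rather than tensors, so the sketch runs without PyTorch.

```python
# Hypothetical sketch: names and grouping rule are illustrative, not DeepOBS code.


def regularization_groups_buggy(params, l2_reg):
    """Build {l2_factor: [parameters]} but forget to return it (the bug)."""
    groups = {0.0: [], l2_reg: []}
    for name, values in params.items():
        # Assumption: only weights are regularized, biases are not.
        if name.endswith("weight"):
            groups[l2_reg].append(values)
        else:
            groups[0.0].append(values)
    # Missing `return groups`, so the caller receives None.


def regularization_groups_fixed(params, l2_reg):
    """Same grouping, with the return statement restored."""
    groups = {0.0: [], l2_reg: []}
    for name, values in params.items():
        if name.endswith("weight"):
            groups[l2_reg].append(values)
        else:
            groups[0.0].append(values)
    return groups


def regularization_loss(groups):
    """Assumed L2 penalty: 0.5 * sum of factor * squared norm per parameter."""
    total = 0.0
    for factor, group in groups.items():  # AttributeError if groups is None
        for values in group:
            total += factor * sum(v * v for v in values)
    return 0.5 * total


params = {"conv.weight": [1.0, 2.0], "conv.bias": [0.5]}

try:
    regularization_loss(regularization_groups_buggy(params, 5e-4))
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'items'

print(regularization_loss(regularization_groups_fixed(params, 5e-4)))
```

Calling `.items()` on the `None` returned by the buggy builder reproduces the same `AttributeError` message seen in the demo; restoring the `return` is enough to make the regularized loss computable.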
Demo:
- Before: The forward pass without regularization succeeds; the forward pass with regularization raises `AttributeError: 'NoneType' object has no attribute 'items'`.
- After: Both forward passes work.
```python
from deepobs.pytorch.testproblems import cifar100_wrn404

tp = cifar100_wrn404(5)  # batch size 5
tp.set_up()
tp.train_init_op()

# works
loss, _ = tp.get_batch_loss_and_accuracy(add_regularization_if_available=False)
print(loss)

# crashes before the fix
regularized_loss, _ = tp.get_batch_loss_and_accuracy(
    add_regularization_if_available=True
)
print(regularized_loss)
```
Minor:
- Moved the testproblem `set_up` method to `__init__` of the test.