
Better error messages for BatchNorm

fKunstner opened this pull request 3 years ago · 2 comments

Makes the BatchNorm error message more explicit to avoid confusion (https://github.com/f-dangel/backpack/issues/239) and adds an option to ignore the exception.

Summary of changes:

  • Make BatchNorm error message more explicit
  • Tie whether BatchNorm raises an exception or a warning to fail_mode
  • Split error handling for missing extensions between first and second order: the current handling mixed second-order extensions (which should fail if no extension is defined) and first-order extensions (which should fail only if no extension is defined and the module has parameters). The checks moved to the first-order and second-order base classes.
  • Change the default fail_mode to error and make fail_mode user-accessible: first-order extensions did not expose fail_mode and defaulted to warn (see the sketch after this list).
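
For context, a minimal sketch of the usage this would enable, assuming fail_mode is exposed as a keyword argument on the extension as discussed below (the exact argument name and values may differ from the final PR):

```python
from torch import nn, rand
from backpack import backpack, extend
from backpack.extensions import BatchGrad

# A model containing a BatchNorm layer in train mode.
model = extend(nn.Sequential(nn.BatchNorm1d(3), nn.Linear(3, 1)))
lossfunc = extend(nn.MSELoss())
loss = lossfunc(model(rand(8, 3)), rand(8, 1))

# With the proposed default (fail_mode="ERROR") this would raise on the
# BatchNorm layer; passing fail_mode="WARNING" downgrades it to a warning.
with backpack(BatchGrad(fail_mode="WARNING")):
    loss.backward()

for param in model.parameters():
    print(param.grad_batch.shape)  # leading dimension is the batch size
```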

fKunstner · Feb 10 '22 19:02

Hey Fred, I skimmed through your changes:

  1. The failing test checks if the result in batch_grad for a BN layer in train mode sums to grad (see the sketch after this list). Working with batch_grad is 'okay' in this case because it's not interpreted as per-sample gradients. We could either revert the default for the first-order fail mode, or adapt the test to use BatchGrad(fail_mode="WARNING"). I would currently favor reverting the default (as this also does not trigger a version bump, and fixes 2.).
  2. The RTD example with the custom ResNet fails for similar reasons as in 1.
  3. Can you pip install --upgrade black && make black to update the formatting?
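
For reference, a rough sketch of what the check in 1. amounts to, reusing the setup from the sketch above (the actual test lives in the BackPACK test suite and differs in detail):

```python
import torch

# After the backward pass above, the per-sample gradients in grad_batch
# should sum (over the batch dimension) to the aggregate gradient in grad,
# even for the BatchNorm parameters in train mode.
for param in model.parameters():
    assert torch.allclose(param.grad_batch.sum(0), param.grad, atol=1e-6)
```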

Happy to review or discuss!

f-dangel · Feb 15 '22 10:02

Thanks for the check!

I'd lean more towards crash than warn, but to get to something we can 👍 on, how about starting from this setup:

  • Revert the default to fail_mode = "WARNING" for first-order extensions
  • Change the error message to "Use at your own risk"
  • Add a notice that "This is not supported and might throw an error in a future version"? (See the sketch below.)
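
Put together, the compromise could look roughly like this (the function name, message wording, and fail-mode values are illustrative, not BackPACK's actual internals):

```python
import warnings
from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d

def check_batch_norm(module, fail_mode="WARNING"):
    """Hypothetical check tying the BatchNorm failure to fail_mode."""
    is_bn = isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d))
    if is_bn and module.training:
        message = (
            "Encountered BatchNorm module in training mode. "
            "Quantities computed by first-order extensions are not "
            "per-sample. Use at your own risk. This is not supported "
            "and might throw an error in a future version."
        )
        if fail_mode == "ERROR":
            raise NotImplementedError(message)
        warnings.warn(message)
```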

The failing test checks if the result in batch_grad for a BN layer in train mode sums to grad. Working with batch_grad is 'okay' in this case because it's not interpreted as per-sample gradients.

I don't follow the "batch_grad is okay". Do you mean in the context of the tests? If so, I agree that BatchGrad should sum to Grad with or without BatchNorm. But I don't think this should be the default behavior of the user-facing API. Someone calling batch_grad is expecting individual gradients and should get an error (maybe a strong warning works as well).
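
To illustrate why (a toy example, not from this PR): in train mode, BatchNorm normalizes with batch statistics, so each sample's output, and hence its "individual gradient", depends on every other sample in the batch:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(3).train()
x = torch.rand(8, 3)
out_before = bn(x).detach()

x_perturbed = x.clone()
x_perturbed[0] += 1.0  # change only sample 0
out_after = bn(x_perturbed).detach()

# The outputs of the untouched samples 1..7 change too, because the batch
# mean and variance changed. Individual gradients are thus ill-defined.
print(torch.allclose(out_before[1:], out_after[1:]))  # False
```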

Can you pip install --upgrade black && make black to update the formatting?

The files that black complains about are not part of this PR(?). I'll merge main in there again.

fKunstner · Feb 15 '22 18:02