Better error messages for BatchNorm
Makes the BatchNorm error message more explicit to avoid confusion (https://github.com/f-dangel/backpack/issues/239) and adds an option to ignore the exception.
Summary of changes:
- Make BatchNorm error message more explicit
- Tie whether BatchNorm raises an exception or a warning to `fail_mode`
- Split the error handling for missing extensions between first- and second-order extensions. The previous handling of missing modules mixed second-order extensions (which should fail if no extension is defined) and first-order extensions (which should fail only if no extension is defined and the module has parameters). It was moved to the first-order and second-order base classes.
- Change the default `fail_mode` to error and make `fail_mode` user accessible. First-order extensions did not expose `fail_mode` and had `warn` as a default (see the sketch below).
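For reference, a minimal sketch of how the user-facing option would look, assuming the `fail_mode` argument proposed in this PR (with the proposed default, the extension raises for a BatchNorm layer in train mode; passing `fail_mode="WARNING"` downgrades this to a warning):

```python
# Sketch only: assumes the user-accessible fail_mode argument proposed in this PR.
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = extend(torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.BatchNorm1d(5)))
lossfunc = extend(torch.nn.MSELoss())
loss = lossfunc(model(torch.randn(8, 10)), torch.randn(8, 5))

# Proposed default (fail_mode="ERROR"): BatchGrad raises for the BatchNorm
# layer in train mode. Users who accept that the results are not per-sample
# gradients could opt into a warning instead:
with backpack(BatchGrad(fail_mode="WARNING")):
    loss.backward()

print(model[0].weight.grad_batch.shape)  # quantities stored by BatchGrad
```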
Hey Fred, I skimmed through your changes:
1. The failing test checks if the result in `batch_grad` for a BN layer in train mode sums to `grad` (see the sketch below). Working with `batch_grad` is 'okay' in this case because it is not interpreted as per-sample gradients. We could either revert the default for first-order fail mode, or adapt the test to use `BatchGrad(fail_mode="WARNING")`. I would currently favor reverting the default, as this also does not trigger a version bump and fixes 2.
2. The RTD example with the custom ResNet fails for similar reasons as in 1.
3. Can you `pip install --upgrade && make black` to update the formatting?
Happy to review or discuss!
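For context, a sketch of the check the failing test performs, under the same assumptions as above (the `fail_mode` keyword from this PR, set to `"WARNING"` so `BatchGrad` does not raise on the BatchNorm layer in train mode):

```python
# Sketch of the "batch_grad sums to grad" check discussed in point 1.
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = extend(torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.BatchNorm1d(5)))
lossfunc = extend(torch.nn.MSELoss())
loss = lossfunc(model(torch.randn(8, 10)), torch.randn(8, 5))

with backpack(BatchGrad(fail_mode="WARNING")):  # assumption: keyword from this PR
    loss.backward()

for param in model.parameters():
    # grad_batch has shape [N, *param.shape]; summing over the batch axis
    # should recover param.grad, even though the individual slices are not
    # valid per-sample gradients while BatchNorm is in train mode.
    assert torch.allclose(param.grad_batch.sum(0), param.grad, atol=1e-6)
```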
Thanks for the check!
I'd lean more towards crash than warn, but to get to something we can 👍 on, how about, starting from this setup:
- Revert the default to `fail_mode = Warning` for first-order extensions
- Change the error message to `Use at your own risk`
- Add a notice that `This is not supported` and `might throw an error in a future version`? (A rough sketch of this behavior is below.)
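Roughly what that behavior could look like; the helper name and message wording are illustrative, not BackPACK's actual internals:

```python
import warnings

# Hypothetical helper (not part of BackPACK): illustrates the proposed
# warn-by-default behavior of first-order extensions for BatchNorm in train mode.
def _check_batchnorm_support(fail_mode: str = "WARNING") -> None:
    message = (
        "Encountered BatchNorm module in training mode. The quantities computed "
        "by this extension are not per-sample gradients. Use at your own risk. "
        "This is not supported and might throw an error in a future version."
    )
    if fail_mode == "ERROR":
        raise NotImplementedError(message)
    warnings.warn(message)
```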
> The failing test checks if the result in `batch_grad` for a BN layer in train mode sums to `grad`. Working with `batch_grad` is 'okay' in this case because it's not interpreted as per-sample gradients
I don't follow the "`batch_grad` is okay" part. Do you mean in the context of the tests? If so, I agree that `BatchGrad` should sum to `Grad` with or without BatchNorm. But I don't think this should be the default behavior of the user-facing API. Someone calling `batch_grad` is expecting individual gradients and should get an error (maybe a strong warning works as well).
> Can you `pip install --upgrade && make black` to update the formatting?
The files that `black` complains about are not part of this PR(?). I'll merge main in there again.