[RFC] Functorch to automatically compute per-sample gradients
TL;DR: We can use Functorch to compute per-sample gradients automatically.
This PR proposes to use Functorch to compute per-sample gradients. The "meat" is in `grad_sample/functorch.py`. It essentially runs a mini per-sample forward/backward on the layer at hand, using the trick that `outputs.backward(backprops)` is equivalent to `(outputs * backprops).sum().backward()`.
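For illustration, here is a minimal, self-contained sketch of that trick using `torch.func` (the in-tree successor of functorch); the layer, shapes, and function names below are made up for the example and are not the actual Opacus implementation:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

layer = nn.Linear(4, 3)
params = {name: p.detach() for name, p in layer.named_parameters()}

batch_size = 8
activations = torch.randn(batch_size, 4)  # layer inputs captured in the forward pass
backprops = torch.randn(batch_size, 3)    # grads w.r.t. layer outputs from the backward pass

def per_sample_loss(params, activation, backprop):
    # outputs.backward(backprops) is equivalent to (outputs * backprops).sum().backward(),
    # so this scalar reproduces the layer's backward for a single sample
    output = functional_call(layer, params, (activation.unsqueeze(0),))
    return (output * backprop).sum()

# grad differentiates w.r.t. the params dict; vmap maps it over the batch dimension,
# yielding one gradient per sample instead of a summed batch gradient
per_sample_grads = vmap(grad(per_sample_loss), in_dims=(None, 0, 0))(
    params, activations, backprops
)

for name, g in per_sample_grads.items():
    print(name, g.shape)  # weight: (8, 3, 4), bias: (8, 3)
```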
Open questions
- How do we validate the modules? It used to be clear that any module for which no grad_sampler was implemented was not OK, but now it's different. Should we move from a model where everything not explicitly allowed is forbidden to a model where everything not forbidden is allowed?
- Do we want to provide clean support for activating/deactivating `GRAD_SAMPLERS`, or is it OK to just `del GRAD_SAMPLERS[module]`? (A sketch of the ad-hoc route follows this list.)
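For context, the ad-hoc route would look roughly like the sketch below, assuming `GRAD_SAMPLERS` is the class-level registry on `GradSampleModule` keyed by module type (exact names may differ across Opacus versions); a dedicated de-registration API would wrap this in a supported call:

```python
import torch.nn as nn
from opacus.grad_sample import GradSampleModule

# Ad-hoc deactivation: drop the registered grad sampler for nn.Linear so that
# the functorch-based path (or validation) handles that module type instead.
if nn.Linear in GradSampleModule.GRAD_SAMPLERS:
    del GradSampleModule.GRAD_SAMPLERS[nn.Linear]
```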