
[RFC] Functorch to automatically compute per-sample gradients

Open · alexandresablayrolles opened this issue Jul 04, 2022 · 0 comments

TL;DR: We can use Functorch to compute per-sample gradients automatically.

This PR proposes to use Functorch to compute per-sample gradients. The "meat" is in grad_sample/functorch.py: it essentially runs a mini "per-sample" forward/backward pass on the layer at hand, using the trick that outputs.backward(backprops) is equivalent to (outputs * backprops).sum().backward().
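
To make the trick concrete, here is a minimal, self-contained sketch (not the PR's actual code in grad_sample/functorch.py) of computing per-sample gradients for a single nn.Linear layer from its captured activations and backprops, using functorch's make_functional, grad, and vmap (the same APIs live in torch.func in newer PyTorch):

```python
import torch
import torch.nn as nn
from functorch import grad, make_functional, vmap

layer = nn.Linear(4, 3)
flayer, params = make_functional(layer)  # functional version of the layer + its params

def per_sample_surrogate_loss(params, activation, backprop):
    # Mini "per-sample" forward pass on a batch of one, followed by the
    # surrogate loss (output * backprop).sum(): since outputs.backward(backprops)
    # is equivalent to (outputs * backprops).sum().backward(), the gradient of
    # this surrogate w.r.t. params is the per-sample gradient of the real loss.
    output = flayer(params, activation.unsqueeze(0))
    return (output * backprop.unsqueeze(0)).sum()

compute_grad = grad(per_sample_surrogate_loss)
# vmap over the batch dimension of activations and backprops; params stay shared.
compute_per_sample_grads = vmap(compute_grad, in_dims=(None, 0, 0))

activations = torch.randn(8, 4)  # layer inputs, captured during the forward pass
backprops = torch.randn(8, 3)    # grad of loss w.r.t. layer outputs, from backward
per_sample_grads = compute_per_sample_grads(params, activations, backprops)
# Tuple matching params: weight grads of shape (8, 3, 4), bias grads of shape (8, 3).
```

Because the surrogate loss is linear in the layer's outputs, this recovers per-sample gradients for any layer whose forward pass functorch can trace, without a hand-written grad sampler.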

Open questions

  1. How do we validate modules? It used to be clear that any module for which a grad sampler was not implemented was not supported, but now that is no longer the case. Should we move from a model where everything not explicitly allowed is forbidden to one where everything not forbidden is allowed?
  2. Do we want to provide clean support for activating/de-activating entries in GRAD_SAMPLERS, or is it OK to just del GRAD_SAMPLERS[module]? (A possible shape for such an API is sketched below.)
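
On question 2, here is a hypothetical sketch of what "clean" de-activation could look like instead of mutating the registry directly. The registry shape mirrors Opacus's GRAD_SAMPLERS dict (module class → sampler function), but the context manager itself is an illustration, not an existing API:

```python
from contextlib import contextmanager

import torch.nn as nn

# Stand-in for the real registry (GradSampleModule.GRAD_SAMPLERS in Opacus).
GRAD_SAMPLERS = {}

@contextmanager
def grad_sampler_disabled(module_cls):
    """Temporarily remove a module class's hand-written grad sampler so the
    functorch-based fallback is used, restoring it on exit."""
    sampler = GRAD_SAMPLERS.pop(module_cls, None)
    try:
        yield
    finally:
        if sampler is not None:
            GRAD_SAMPLERS[module_cls] = sampler

# Usage: inside this block, nn.Linear would fall back to the functorch path.
with grad_sampler_disabled(nn.Linear):
    pass  # ... run a training step ...
```

A scoped helper like this avoids the footgun of a bare del, which silently changes behavior for every subsequent model in the process.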

alexandresablayrolles · Jul 04 '22 15:07