Augmentation Multiplicity
🚀 Feature
Augmentation multiplicity, as explained in Section 3.1 of https://arxiv.org/abs/2204.13650: before clipping per-sample gradients, we may average each sample's gradients over its different augmentations.
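For illustration, a toy sketch of the intended ordering in plain PyTorch (not Opacus internals; the function name and shapes are made up for the example):

```python
import torch

def average_then_clip(grads: torch.Tensor, max_norm: float) -> torch.Tensor:
    """Toy illustration of augmentation multiplicity.

    grads: (B, K, D) -- per-sample, per-augmentation gradients for B samples
    with K augmentations each, all parameters flattened into D entries.
    """
    avg = grads.mean(dim=1)                              # (B, D): average over the K augmentations
    norms = avg.norm(dim=1, keepdim=True)                # (B, 1): per-sample L2 norms
    scale = (max_norm / (norms + 1e-6)).clamp(max=1.0)   # standard DP-SGD clipping factor
    return avg * scale                                   # clipped per-sample gradients
```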
Motivation
The paper reports improved results compared to training without augmentation.
If my understanding is correct, this could currently be implemented by averaging over different augmentations in the loss function, but I am not 100% sure that is equivalent.
@timudk Thank you for bringing this up! Do you have ideas on how this could be implemented with Opacus? For example, we could add an option that changes the clipping behaviour. Would we also need to change the privacy accounting?
Just a note that we have implemented this in an internal codebase by making various tweaks to Opacus.
This involves restructuring the clipping step, but the privacy accounting shouldn't change, since each 'user' still contributes one (averaged) gradient.
@lxuechen any plans to make the codebase public anytime soon?
@timudk I think you can use Xuechen's modified privacy engine for this: https://github.com/lxuechen/private-transformers/blob/main/private_transformers/privacy_engine.py
I'm not sure, but it looks similar to my internal implementation of the augmentation multiplicity trick. It does require a somewhat significant API modification: the optimizer itself must take the 1-D per-example loss tensor as input, and I'm not sure how to get around that one way or another.
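For context, that calling convention looks roughly like this (based on the private-transformers README; model, inputs, labels, and an attached optimizer are assumed to exist in the surrounding training loop):

```python
import torch.nn.functional as F

# After privacy_engine.attach(optimizer): hand the 1-D per-example loss
# to the optimizer instead of calling loss.backward() yourself.
logits = model(inputs)
loss = F.cross_entropy(logits, labels, reduction="none")  # shape (B,), one loss per example
optimizer.step(loss=loss)  # the attached engine handles per-example backward,
                           # clipping, noising, and the parameter update
```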
Hi folks, thanks for following up.
That codebase actually does not yet include the augmentation multiplicity trick. We're working on a project that involves it and will release the code when the project is done.
Implementing the trick is actually not that hard. You could, for instance, duplicate indices in the sampler and aggregate the per-sample gradients after they are instantiated.
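A minimal sketch of that suggestion, assuming augmentations are applied stochastically in the dataset's __getitem__ so that repeated indices yield different views (the class and function names here are illustrative, not Opacus API):

```python
import torch
from torch.utils.data import Sampler

class RepeatSampler(Sampler):
    """Yield each index k times in a row, so the k augmented views of a
    sample end up adjacent in the physical batch."""
    def __init__(self, indices, k):
        self.indices, self.k = indices, k

    def __iter__(self):
        return iter(i for i in self.indices for _ in range(self.k))

    def __len__(self):
        return len(self.indices) * self.k

def average_grad_samples(model, k):
    """After loss.backward() under per-sample-gradient hooks, fold the
    augmentation axis out of each grad_sample so that clipping sees one
    (averaged) gradient per sample rather than per view."""
    for p in model.parameters():
        if getattr(p, "grad_sample", None) is not None:
            gs = p.grad_sample                                            # (B * k, *p.shape)
            p.grad_sample = gs.reshape(-1, k, *gs.shape[1:]).mean(dim=1)  # (B, *p.shape)
```

Since each example still contributes exactly one clipped gradient, this stays consistent with the accounting point above.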
Just wanted to point out that this is now possible with functorch (grad_sample_mode="no_op"). It should be straightforward to adapt, e.g., the CIFAR-10 example to handle data augmentations.
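For example, with torch.func (the successor to functorch), per-sample gradients averaged over K augmented views can be computed as in the sketch below. This assumes a classification model with inputs x of shape (B, K, C, H, W) and integer labels y of shape (B,); the function name is made up:

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

def per_sample_grads_over_views(model, x, y):
    """Return a dict of per-sample gradients, each of shape (B, *param.shape),
    already averaged over the K augmented views of each sample."""
    params = {k: v.detach() for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}

    def loss_fn(params, buffers, x_views, label):
        out = functional_call(model, (params, buffers), (x_views,))  # (K, num_classes)
        # Mean over the K views: the gradient of the mean is the mean of the
        # per-view gradients, which is exactly augmentation multiplicity.
        return F.cross_entropy(out, label.expand(x_views.shape[0]))

    # vmap over the batch dimension only; params and buffers are shared.
    return vmap(grad(loss_fn), in_dims=(None, None, 0, 0))(params, buffers, x, y)
```

As I understand the no_op mode, per-sample gradients are supplied by the user, e.g. by assigning them to each parameter's grad_sample attribute before optimizer.step(), so these averaged gradients can be plugged in directly.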