Opacus: I made a model-agnostic callback for PyTorch Lightning
🚀 Feature
I was able to get Opacus to work with PyTorch Lightning (pl) using a pl.Callback. Note that the callback is model-agnostic: the model's pl.LightningModule class does not contain anything related to Opacus.
Motivation
We need an easy way for PyTorch Lightning users to use Opacus without having to refactor their LightningModule classes. See below.
Pitch
We need something like the following. (I have actually implemented this, but it currently works only for models with a single optimizer.)
import pytorch_lightning as pl
from opacus import OpacusCallback
from pl_bolts.models import LitMNIST
from pl_bolts.datamodules import MNISTDataModule

trainer = pl.Trainer(
    callbacks=[
        OpacusCallback(...),  # all that is needed for DP-training
    ],
)
trainer.fit(model=LitMNIST(...), datamodule=MNISTDataModule(...))
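For reference, a minimal sketch of what the constructor of such a callback could look like, assuming it simply stores a PrivacyEngine plus the DP hyperparameters that the hook under "Additional context" below passes to make_private (names such as target_delta are illustrative, not an existing Opacus API):

# --- sketch, not part of Opacus: a possible OpacusCallback constructor --- #
import typing as ty

import pytorch_lightning as pl
from opacus import PrivacyEngine


class OpacusCallback(pl.Callback):
    def __init__(
        self,
        noise_multiplier: float = 1.0,
        max_grad_norm: float = 1.0,
        clipping: str = "flat",  # "flat", "per_layer" or "adaptive"
        poisson_sampling: bool = True,
        target_delta: float = 1e-5,  # hypothetical: used later to query epsilon
    ) -> None:
        super().__init__()
        # the engine is created once and reused in the training hooks
        self.privacy_engine = PrivacyEngine()
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.clipping = clipping
        self.poisson_sampling = poisson_sampling
        self.target_delta = target_delta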
Additional context
In my version of OpacusCallback, all I do is call .make_private in the on_train_epoch_start hook:
# --- pseudo code --- #
def on_train_epoch_start(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
) -> None:
    optimizers: ty.List[Optimizer] = []
    for optimizer in trainer.optimizers:
        # this works: make_private returns (module, optimizer, data_loader),
        # and only the DP optimizer is kept here
        _, dp_optimizer, _ = self.privacy_engine.make_private(  # or make_private_with_epsilon
            module=pl_module,
            optimizer=optimizer,
            data_loader=trainer._data_connector._train_dataloader_source.dataloader(),
            noise_multiplier=self.noise_multiplier,
            max_grad_norm=self.max_grad_norm,
            clipping=self.clipping,  # "flat" or "per_layer" or "adaptive"
            poisson_sampling=self.poisson_sampling,
        )
        optimizers.append(dp_optimizer)
        ### this will fail
        # if not hasattr(pl_module, "autograd_grad_sample_hooks"):
        #     pl_module = GradSampleModule(pl_module)
        # dp_optimizer = privacy_engine._prepare_optimizer(
        #     optimizer,
        #     noise_multiplier=self.noise_multiplier,
        #     max_grad_norm=self.max_grad_norm,
        #     expected_batch_size=expected_batch_size,
        # )
    trainer.optimizers = optimizers
What's cool is that this is also an EarlyStopping callback, so it will stop training once enough of the privacy budget has been spent.
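For completeness, here is a minimal sketch of how that budget check could be wired up, assuming the callback stores hypothetical target_epsilon / target_delta attributes (not part of Opacus) and queries the accountant via PrivacyEngine.get_epsilon:

# --- sketch: stopping once the privacy budget is exhausted --- #
def on_train_epoch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
) -> None:
    # ask the privacy accountant how much epsilon has been spent so far
    spent_epsilon = self.privacy_engine.get_epsilon(delta=self.target_delta)
    if spent_epsilon >= self.target_epsilon:
        # same mechanism Lightning's EarlyStopping uses to end training
        trainer.should_stop = True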