Refactor stochastic depth to generalize to novel models
🚀 Feature Request
Use torch.fx to automatically detect residual blocks and residual connections in a model, then manipulate these components to perform stochastic depth.
Motivation
Right now, stochastic depth replaces a hard-coded module (e.g., composer.models.resnets.Bottleneck) with a manually defined stochastic version of that module. The stochastic module randomly skips the main computation during training and multiplies the residual connection by the probability of skipping.
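For context, the replacement module behaves roughly like the sketch below. This is an illustrative approximation of the standard stochastic-depth recipe (skip the main branch with some drop rate during training, scale it by the survival probability at evaluation), not Composer's exact code; `StochasticBlock` and `drop_rate` are hypothetical names:

```python
import torch
from torch import nn

class StochasticBlock(nn.Module):
    """Illustrative sketch: wraps a residual block's main branch so it is
    randomly skipped during training (standard stochastic depth)."""

    def __init__(self, main_branch: nn.Module, drop_rate: float = 0.2):
        super().__init__()
        self.main_branch = main_branch
        self.drop_rate = drop_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Skip the main computation with probability drop_rate;
            # only the residual (identity) connection survives.
            if torch.rand(1).item() < self.drop_rate:
                return x
            return x + self.main_branch(x)
        # At evaluation, scale the main branch by its survival probability.
        return x + (1.0 - self.drop_rate) * self.main_branch(x)
```

The point of this issue is that wrapping blocks this way currently requires knowing the block class for each architecture ahead of time.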
To avoid manually specifying both the module to replace and a corresponding stochastic module, the residual blocks and residual connections need to be identified and manipulated automatically. torch.fx may provide the tools to do this.
Automatic identification and manipulation would allow stochastic depth to be applied to several models without hard-coded specification:
- CIFAR ResNets
- EfficientNet
- GPT-2
- BERT
- DeepLabv3
Implementation
Vague idea for how to do this:
- Identify two-argument `add` operations in a model architecture
- Trace the arguments back to a single point in the computation graph -> the start of the residual block
- Add a conditional statement, gated on a Bernoulli variable, after that single point in the computation graph
- Scale the residual connection by the probability of entering the conditional statement

Caveat: I don't know if conditional statements can be added with torch.fx.
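The first step above (finding the two-argument adds) is straightforward with torch.fx's symbolic tracing. A minimal sketch on a toy residual model (`ToyResidualBlock` is made up for illustration; the node-matching logic is the real torch.fx API):

```python
import operator

import torch
import torch.fx as fx
from torch import nn

class ToyResidualBlock(nn.Module):
    """Toy block with one residual connection: out = x + conv(x)."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return x + self.conv(x)  # the two-argument add we want to detect

model = nn.Sequential(ToyResidualBlock(), ToyResidualBlock())
graph_module = fx.symbolic_trace(model)

# Find every call_function node whose target is a two-argument add.
add_nodes = [
    node
    for node in graph_module.graph.nodes
    if node.op == "call_function" and node.target in (operator.add, torch.add)
]
print(len(add_nodes))
```

One caveat worth noting for the conditional-statement step: symbolic tracing produces a single static graph, so a data-dependent `if` can't be traced directly; it would likely have to live inside an inserted `call_module` node whose module contains the conditional.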
Discussion
I don't expect this to work for every architecture, but it should be at least more generalizable than the current implementation. Alternatively, we can update stochastic depth with every new target architecture, but this is not sustainable.
@hanlint Do you know of anyone on research eng I could talk to about this?
Looking at torch.fx, they do support graph rewrites and even subgraph pattern matching (https://pytorch.org/docs/stable/fx.html#torch.fx.replace_pattern). It looks like they return a torch.nn.Module, but we'll need to check its compatibility with the other surgery algorithms that we have.
This is probably lower priority for now, given our push towards usability.
Closing. Tracking elsewhere as low pri