Overparam layers
PyTorch linear over-parameterization layers with automatic graph reduction.
Official codebase used in:
The Low-Rank Simplicity Bias in Deep Networks
Minyoung Huh, Hossein Mobahi, Richard Zhang, Brian Cheung, Pulkit Agrawal, Phillip Isola
MIT CSAIL, Google Research, Adobe Research, MIT BCS
TMLR 2023 (arXiv 2021).
[project page] | [paper] | [arXiv]
1. Installation
Developed on
- Python 3.7 :snake:
- PyTorch 1.7 :fire:
> git clone https://github.com/minyoungg/overparam
> cd overparam
> pip install .
2. Usage
The layers work exactly the same way as standard torch.nn layers.
Getting started
(1a) OverparamLinear layer (equivalence: nn.Linear)
import torch
from overparam import OverparamLinear
layer = OverparamLinear(16, 32, width=1, depth=2)
x = torch.randn(1, 16)
(1b) OverparamConv2d layer (equivalence: nn.Conv2d)
from overparam import OverparamConv2d
import numpy as np
We can construct 3 Conv2d layers with kernel dimensions of 5x5, 3x3, and 1x1:
kernel_sizes = [5, 3, 1]
# Same padding
padding = max((np.sum(kernel_sizes) - len(kernel_sizes) + 1) // 2, 0)
layer = OverparamConv2d(2, 4, kernel_sizes=kernel_sizes, padding=padding, depth=len(kernel_sizes))
# Get the effective kernel size
print(layer.kernel_size)  # effective kernel size: 5 + 3 + 1 - 2 = 7
When kernel_sizes is an integer, all subsequent layers are assumed to have a kernel size of 1x1.
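For instance, here is a minimal sketch (assuming the same constructor arguments as above) of a depth-3 layer built from an integer kernel_sizes, where only the first layer is 3x3:
# first layer is 3x3, the remaining depth-1 layers default to 1x1,
# so the effective kernel size stays 3
layer = OverparamConv2d(2, 4, kernel_sizes=3, padding=1, depth=3)
print(layer.kernel_size)  # expected effective kernel size: 3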
(2) Forward computation
# Forward pass (expanded form)
layer.train()
y = layer(x)
When eval() is called, the model automatically reduces its computation graph to the effective single-layer counterpart. The forward pass in eval mode then uses these effective weights instead.
# Forward pass (collapsed form) [automatic]
layer.eval()
y = layer(x)
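Because the stacked layers are purely linear, the expanded and collapsed forms should agree up to floating-point error. A minimal sanity check, assuming the depth-2 OverparamLinear layer and input x from the getting-started example:
layer.train()
y_expanded = layer(x)   # expanded (multi-layer) forward pass
layer.eval()
y_collapsed = layer(x)  # collapsed (single-layer) forward pass
print(torch.allclose(y_expanded, y_collapsed, atol=1e-6))  # expected: True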
You can access the effective weights as follows:
print(layer.weight)  # effective weight, same shape as the equivalent nn.Linear weight
print(layer.bias)    # effective bias
(3) Automatic conversion
import torchvision.models as models
from overparam.utils import overparameterize
model = models.alexnet() # Replace this with YOUR_PYTORCH_MODEL()
model = overparameterize(model, depth=2)
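As a quick sanity check (a hedged sketch, assuming overparameterize replaces the supported nn.Linear/nn.Conv2d modules in place), you can count the converted layers:
from overparam import OverparamConv2d, OverparamLinear
# count how many modules were swapped for over-parameterized counterparts
num_expanded = sum(isinstance(m, (OverparamConv2d, OverparamLinear)) for m in model.modules())
print(f"over-parameterized layers: {num_expanded}")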
(4) Batch-norm and Residual connections
We also provide support for batch-norm and linear residual connections.
- batch-normalization (pseudo-linear layer: linear during eval mode)
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
batch_norm=True)
- residual connection
# a residual connection is added every 2 layers
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
residual=True, residual_intervals=2)
- multiple residual connections
# residual connections are added at intervals of 1, 2, and 3 layers
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
residual=True, residual_intervals=[1, 2, 3])
- batch-norm and residual connection
# mimics `BasicBlock` in ResNets
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
batch_norm=True, residual=True, residual_intervals=2)
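Even with batch-norm and residual connections, the layer should still collapse to its effective single-layer form once eval() is called, since batch-norm is linear in eval mode (a hedged sketch using the layer defined above):
layer.eval()               # batch-norm becomes linear, so the graph collapses
print(layer.weight.shape)  # effective single-layer convolution kernel
print(layer.bias.shape)    # effective single-layer bias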
3. Cite
@article{huh2023simplicitybias,
title={The Low-Rank Simplicity Bias in Deep Networks},
author={Minyoung Huh and Hossein Mobahi and Richard Zhang and Brian Cheung and Pulkit Agrawal and Phillip Isola},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2023},
url={https://openreview.net/forum?id=bCiNWDmlY2},
}