gnp
gnp copied to clipboard
gradient norm penalty
Gradient Regularization in Deep Learning
Table of Contents
- Works Related
- Upgrade [2024.6.15]
- Training using this repo
- Short intro
Works Related
- "Penalizing Gradient Norm for Efficiently Improving Generalization in Deep Learning"[ICML2022], by Yang Zhao, Hao Zhang and Xiuyuan Hu.
- "When Will Gradient Regularization Be Harmful?"[ICML2024], by Yang Zhao, Hao Zhang and Xiuyuan Hu.
Upgrade [2024.6.15]
- JAX Framework Update: Upgraded the training framework to the latest version (JAX 0.4.28).
- New Paper Implementation: Integrated the implementation of our latest research paper into this repository.
- Additional Model Architectures:: Included Swin and CaiT Transformer architectures in the model list.
Training using this repo
-
Environment Setup: This repository is built using the JAX framework. Begin by setting up the Python environment specified in the
requirements.txt
file. -
Configuration: The
config
folder contains all the configuration flags and their default values. You can add custom flags if needed by modifying these files. -
Model Architectures: The
model
folder includes various model architectures such as VGG, ResNet, WideResNet, PyramidNet, ViT, Swin and CaiT. To add custom models, follow the Flax model template and register your model using the_register_model
function in this folder. -
Dataset Pipeline: The
ds_pipeline
folder provides the dataset pipeline, based primarily on the SAM repository. Unlike SAM, this repo uses local ImageNet data instead of tensorflow_datasets. Specify the path to your local dataset folders, ensuring the folder structure is:ImageNet folder └───n01440764 │ │ *.JPEG │ └───n01443537 │ │ *.JPEG ...
-
Optimizers: The
optimizer
folder contains the optimizers, including SGD (Momentum), AdamW and RMSProp. You can add custom optimizers by modifying these files. -
Training Recipes: The
recipe
folder contains.sh
files, each corresponding to a specific model's training script. To run a training script, use the following command:bash wideresnet-cifar.sh
Alternatively, to deploy configurations directly (ensuring the config flag is in the config file), use:
python3 -m gnp.main.main --config=the-train-config-py-file --working_dir=your-output-dir --config.config-anything-else-here
Short intro
1.Overview
Basically, gradient regularization (GR) could be understood as gradient norm penalty, where an additional term regarding the gradient norm $||\nabla_{\theta} L(\theta)||_2$ will be added on top of the empirical loss,
$$\begin{aligned} L(\theta) = L_{\mathcal{S}}(\theta) + \lambda ||\nabla_{\theta} L_{\mathcal{S}}(\theta)||_2 \end{aligned}$$
Gradient norm is considered as a key property that could characterize the flatness of the loss surface. By penalizing the gradient norm, the optimization is encouraged to converge to flatter minima on the loss surface. This results in improved model generalization.
2. Practical Gradient Computation of Gradient Norm
Based on the chain rule, the gradient of the gradient norm is given by:
$$ \nabla_{\theta} L(\theta) = \nabla_{\theta} L_{\mathcal{S}}(\theta) + \lambda \cdot \nabla_{\theta}^2 L_{\mathcal{S}}(\theta) \frac{\nabla_{\theta} L_{\mathcal{S}}(\theta)}{||\nabla_{\theta} L_{\mathcal{S}}(\theta)||} $$
Computing the gradient of this gradient norm term directly involves the full computation of the Hessian matrix. To address this, we use a Taylor expansion to approximate the multiplication between the Hessian matrix and vectors, resulting in:
$$\begin{split} \nabla_{\theta} L(\theta) & = \nabla_{\theta} L_{\mathcal{S}}(\theta) + \lambda \cdot (\frac{\nabla_{\theta}L_\mathcal{S}(\theta +r\frac{\nabla_{\theta}L_{\mathcal{S}}(\theta)}{||\nabla_{\theta}L_{\mathcal{S}}(\theta)||}) - \nabla_{\theta}L_\mathcal{S}(\theta)}{r}) \ & = (1 - \frac{\lambda}{r}) \nabla_{\theta} L_{\mathcal{S}}(\theta) + \frac{\lambda}{r} \cdot \nabla_{\theta}L_\mathcal{S}(\theta +r\frac{\nabla_{\theta}L_{\mathcal{S}}(\theta)}{||\nabla_{\theta}L_{\mathcal{S}}(\theta)||}) \end{split}$$
where $r$ is a small scalar value. So, we need to set two parameters for gradient norm penalty $\lambda$, one for the penalty coefficient and the other one for $r$. For practical convenience, we will further set,
$$\begin{split} \nabla_{\theta} L(\theta) = (1 - \alpha) \nabla_{\theta} L_{\mathcal{S}}(\theta) + \alpha \cdot \nabla_{\theta}L_\mathcal{S}(\theta +r\frac{\nabla_{\theta}L_{\mathcal{S}}(\theta)}{||\nabla_{\theta}L_{\mathcal{S}}(\theta)||}), ~~~\alpha = \frac{\lambda}{r} \end{split}$$
Notably, the SAM algorithm is a special implementation of this scheme where $\alpha$ is always set to 1.
3. Be Careful When using Gradient Regularization with Adaptive Optimizer
GR can lead to serious performance degeneration in the specific scenarios of adaptive optimization.
Error Rate[Cifar10]
Model | Adam | Adam + GR | Adam + GR + Zero-GR-Warmup |
---|---|---|---|
ViT-Ti | 14.82 | 13.92 | 13.61 |
ViT-S | 12.07 | 12.40 |
10.68 |
ViT-B | 10.83 | 12.36 |
9.42 |
With both our empirical observations and theoretical analysis, we find that the biased estimation introduced in GR can induce the instability and divergence in gradient statistics of adaptive optimizers at the initial stage of training, especially with a learning rate warmup technique which originally aims to benefit gradient statistics.
To mitigate this issue, we draw inspirations from the idea of warmup techniques, and propose three GR warmup strategies: $\lambda$-warmup, $r$-warmup and zero-warmup GR. ach of the three strategies can relax the GR effect during warmup course in certain ways to ensure the accuracy of gradient statistics. See paper for details.
End
If you find this helpful, you could cite the papers as
@inproceedings{zhao2022penalizing,
title={Penalizing gradient norm for efficiently improving generalization in deep learning},
author={Zhao, Yang and Zhang, Hao and Hu, Xiuyuan},
booktitle={International Conference on Machine Learning},
pages={26982--26992},
year={2022},
organization={PMLR}
}
@inproceedings{zhaowill,
title={When Will Gradient Regularization Be Harmful?},
author={Zhao, Yang and Zhang, Hao and Hu, Xiuyuan},
booktitle={Forty-first International Conference on Machine Learning}
}