grokking
Re-implementation of the paper 'Grokking: Generalization beyond overfitting on small algorithmic datasets'
Paper
Original paper can be found here
Datasets
All datasets from the original paper's appendix are supported.
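To give a feel for what the arithmetic datasets contain, here is a minimal sketch that enumerates a binary operation over the residues modulo a prime and splits the result into training and validation parts. The function names (make_arithmetic_table, split) are hypothetical and not taken from this repository, and the actual data pipeline may differ. The sketch only assumes that each example is a triple (x, y, result) and that division means multiplication by the modular inverse.

```python
import random


def make_arithmetic_table(op, p=97):
    """Return [(x, y, op(x, y) mod p)] for all valid operand pairs (illustrative helper)."""
    ops = {
        "plus": lambda x, y: (x + y) % p,
        "minus": lambda x, y: (x - y) % p,
        # Modular division: multiply by the inverse of y modulo p (needs Python 3.8+).
        "div": lambda x, y: (x * pow(y, -1, p)) % p,
    }
    fn = ops[op]
    # y = 0 has no modular inverse, so division skips it.
    ys = range(1, p) if op == "div" else range(p)
    return [(x, y, fn(x, y)) for x in range(p) for y in ys]


def split(examples, train_ratio=0.5, seed=42):
    """Shuffle deterministically, then cut into train and validation parts."""
    rng = random.Random(seed)
    shuffled = examples[:]
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]


table = make_arithmetic_table("plus", p=97)
train, val = split(table, train_ratio=0.5, seed=42)
print(len(table), len(train), len(val))  # 9409 4704 4705
```

With p = 97 the full table has 97 * 97 = 9409 examples, so a train ratio of 0.5 leaves only a few thousand equations to learn from.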
Hyperparameters
The default hyperparameters are taken from the paper and can be overridden on the command line when running train.py.
Running experiments
To run with default settings, simply run python train.py.
The first time you train on any dataset, you have to specify --force_data.
Arguments:
optimizer args
- "--lr", type=float, default=1e-3
- "--weight_decay", type=float, default=1
- "--beta1", type=float, default=0.9
- "--beta2", type=float, default=0.98
model args
- "--num_heads", type=int, default=4
- "--layers", type=int, default=2
- "--width", type=int, default=128
data args
- "--data_name", type=str, default="perm", choices=[
- "perm_xy", # permutation composition x * y
- "perm_xyx1", # permutation composition x * y * x^-1
- "perm_xyx", # permutation composition x * y * x
- "plus", # x + y
- "minus", # x - y
- "div", # x / y
- "div_odd", # x / y if y is odd else x - y
- "x2y2", # x^2 + y^2
- "x2xyy2", # x^2 + y^2 + xy
- "x2xyy2x", # x^2 + y^2 + xy + x
- "x3xy", # x^3 + xy
- "x3xy2y", # x^3 + xy^2 + y
]
- "--num_elements", type=int, default=5 (choose 5 for permutation data, 97 for arithmetic data)
- "--data_dir", type=str, default="./data"
- "--force_data", action="store_true", help="Whether to force dataset creation."
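The three perm_* choices above can be illustrated with a short sketch over the permutations of n elements. This is a sketch under assumptions, not the repository's implementation: the helper names are hypothetical, permutations are represented as tuples, and x * y is read as "apply y first, then x", which may not match the convention used in train.py.

```python
from itertools import permutations


def compose(x, y):
    """Return x * y, read here as 'apply y first, then x' (an assumed convention)."""
    return tuple(x[y[i]] for i in range(len(y)))


def inverse(x):
    """Return the permutation that undoes x."""
    inv = [0] * len(x)
    for i, xi in enumerate(x):
        inv[xi] = i
    return tuple(inv)


def make_perm_table(name, n=5):
    """Return [(index of x, index of y, index of result)] over all pairs of permutations."""
    elems = list(permutations(range(n)))
    index = {perm: i for i, perm in enumerate(elems)}
    ops = {
        "perm_xy": lambda x, y: compose(x, y),
        "perm_xyx1": lambda x, y: compose(compose(x, y), inverse(x)),
        "perm_xyx": lambda x, y: compose(compose(x, y), x),
    }
    fn = ops[name]
    return [(index[x], index[y], index[fn(x, y)]) for x in elems for y in elems]


table = make_perm_table("perm_xy", n=5)
print(len(table))  # 120 * 120 = 14400
```

With 5 elements there are 5! = 120 permutations, so each permutation table has 14400 entries.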
training args
- "--batch_size", type=int, default=512
- "--steps", type=int, default=10**5
- "--train_ratio", type=float, default=0.5
- "--seed", type=int, default=42
- "--verbose", action="store_true"
- "--log_freq", type=int, default=10
- "--num_workers", type=int, default=4
- "--disable_logging", action="store_true", help="Whether to use wandb logging"
- "--checkpoints", type=int, default=None, nargs="*", help="List of number of steps after which to save model."
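To show how a few of these options behave when parsed, the sketch below rebuilds a small subset of them with argparse. It is not the parser from train.py: build_parser is a hypothetical name, most options are left out, and the choices list for --data_name is omitted for brevity.

```python
import argparse


def build_parser():
    """Rebuild a handful of the options listed above (not the full train.py parser)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_name", type=str, default="perm")  # choices omitted in this sketch
    parser.add_argument("--num_elements", type=int, default=5)
    parser.add_argument("--force_data", action="store_true")
    parser.add_argument("--steps", type=int, default=10**5)
    parser.add_argument("--disable_logging", action="store_true")
    parser.add_argument("--checkpoints", type=int, default=None, nargs="*")
    return parser


# Mirrors a command line such as:
#   python train.py --data_name plus --num_elements 97 --force_data --checkpoints 1000 10000
args = build_parser().parse_args(
    ["--data_name", "plus", "--num_elements", "97", "--force_data",
     "--checkpoints", "1000", "10000"]
)
print(args.data_name, args.num_elements, args.force_data, args.checkpoints)
# plus 97 True [1000, 10000]
```

The store_true flags (--force_data, --verbose, --disable_logging) take no value and default to False, while --checkpoints accepts zero or more integers and stays None when the flag is not given.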