grokking
Re-implementation of the paper 'Grokking: Generalization beyond overfitting on small algorithmic datasets'
Original paper can be found here
All datasets from the original paper's appendix are supported.
The default hyperparameters are taken from the paper and can be overridden via command-line arguments.
Running experiments
To run with default settings, simply run python
The first time you train on any dataset, you must pass --force_data so that the dataset is created.
optimizer args
- "--lr", type=float, default=1e-3
- "--weight_decay", type=float, default=1
- "--beta1", type=float, default=0.9
- "--beta2", type=float, default=0.98
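The optimizer flags above could be wired up roughly as follows. This is a minimal sketch using only `argparse`. The helper names `build_parser` and `optimizer_kwargs` are hypothetical, and grouping `beta1`/`beta2` into a `betas` tuple assumes an Adam-style optimizer (the paper uses AdamW), which this README does not confirm for this code base.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser covering only the optimizer flags listed above."""
    parser = argparse.ArgumentParser(description="optimizer args (sketch)")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.98)
    return parser


def optimizer_kwargs(args: argparse.Namespace) -> dict:
    """Collect the parsed flags into Adam-style keyword arguments (assumption)."""
    return {
        "lr": args.lr,
        # argparse only converts string defaults, so cast the int default to float
        "weight_decay": float(args.weight_decay),
        "betas": (args.beta1, args.beta2),
    }


# No flags: every value falls back to its default.
defaults = optimizer_kwargs(build_parser().parse_args([]))

# Override two values, as one would on the command line.
custom = optimizer_kwargs(
    build_parser().parse_args(["--lr", "3e-4", "--weight_decay", "0.1"])
)
```

Passing an explicit argument list to `parse_args` keeps the sketch testable without touching `sys.argv`.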
model args
- "--num_heads", type=int, default=4
- "--layers", type=int, default=2
- "--width", type=int, default=128
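Standard multi-head attention splits the model width evenly across heads, so `--width` should be divisible by `--num_heads`. The check below is a sketch of that constraint, assuming the model uses conventional multi-head attention. `head_dim` is a hypothetical helper, not a function from this repository.

```python
def head_dim(width: int, num_heads: int) -> int:
    """Return the per-head dimension, or raise if width does not split evenly."""
    if width <= 0 or num_heads <= 0:
        raise ValueError("width and num_heads must be positive")
    if width % num_heads != 0:
        raise ValueError(
            f"width={width} is not divisible by num_heads={num_heads}"
        )
    return width // num_heads


# With the defaults listed above: 128 / 4 = 32 dimensions per head.
default_head_dim = head_dim(128, 4)
```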
data args
- "--data_name", type=str, default="perm", choices=[
- "perm_xy", # permutation composition x * y
- "perm_xyx1", # permutation composition x * y * x^-1
- "perm_xyx", # permutation composition x * y * x
- "plus", # x + y
- "minus", # x - y
- "div", # x / y
- "div_odd", # x / y if y is odd else x - y
- "x2y2", # x^2 + y^2
- "x2xyy2", # x^2 + y^2 + xy
- "x2xyy2x", # x^2 + y^2 + xy + x
- "x3xy", # x^3 + xy
- "x3xy2y" # x^3 + xy^2 + y ]
- "--num_elements", type=int, default=5 (choose 5 for permutation data, 97 for arithmetic data)
- "--data_dir", type=str, default="./data"
- "--force_data", action="store_true", help="Whether to force dataset creation."
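To make the operation comments above concrete, the sketch below evaluates most of the arithmetic operations modulo a prime `P` (97, per the `--num_elements` note) and composes permutations represented as tuples. It assumes modular arithmetic as in the paper, that division means multiplication by the modular inverse, and the composition convention `(x * y)(i) = x[y[i]]`. The repository may differ on any of these. `ARITHMETIC_OPS`, `mod_div`, `compose`, and `inverse` are hypothetical names. `pow(y, -1, p)` requires Python 3.8 or newer.

```python
from typing import Callable, Dict, Tuple

P = 97  # prime modulus; assumed to correspond to --num_elements for arithmetic data


def mod_div(x: int, y: int, p: int = P) -> int:
    """x / y modulo p: x times the modular inverse of y (y must be nonzero mod p)."""
    return (x * pow(y, -1, p)) % p


# Each entry maps a --data_name choice to a function (x, y) -> result mod P.
ARITHMETIC_OPS: Dict[str, Callable[[int, int], int]] = {
    "plus": lambda x, y: (x + y) % P,
    "minus": lambda x, y: (x - y) % P,
    "div": mod_div,
    "div_odd": lambda x, y: mod_div(x, y) if y % 2 == 1 else (x - y) % P,
    "x2y2": lambda x, y: (x**2 + y**2) % P,
    "x2xyy2": lambda x, y: (x**2 + x * y + y**2) % P,
    "x2xyy2x": lambda x, y: (x**2 + x * y + y**2 + x) % P,
    "x3xy2y": lambda x, y: (x**3 + x * y**2 + y) % P,
}

Perm = Tuple[int, ...]


def compose(x: Perm, y: Perm) -> Perm:
    """Composition x * y under the convention (x * y)(i) = x[y[i]]."""
    return tuple(x[i] for i in y)


def inverse(x: Perm) -> Perm:
    """Inverse permutation: inverse(x)[x[i]] == i."""
    inv = [0] * len(x)
    for i, xi in enumerate(x):
        inv[xi] = i
    return tuple(inv)


# Example permutations of 5 elements (the default --num_elements).
x = (1, 2, 0, 3, 4)
y = (0, 1, 2, 4, 3)
xy = compose(x, y)                             # "perm_xy"
xyx_inv = compose(compose(x, y), inverse(x))   # "perm_xyx1"
xyx = compose(compose(x, y), x)                # "perm_xyx"
```

Keeping the operations in a dictionary keyed by the `--data_name` choices makes it easy to look one up from a parsed argument, for example `ARITHMETIC_OPS["div"](5, 7)`.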
training args
- "--batch_size", type=int, default=512
- "--steps", type=int, default=10**5
- "--train_ratio", type=float, default=0.5
- "--seed", type=int, default=42
- "--verbose", action="store_true"
- "--log_freq", type=int, default=10
- "--num_workers", type=int, default=4
- "--disable_logging", action="store_true", help="Whether to use wandb logging" (passing this flag turns wandb logging off)
- "--checkpoints", type=int, default=None, nargs="*", help="List of number of steps after which to save model."
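The `--train_ratio` and `--seed` flags suggest a reproducible random split of the full table of equations. The following is a sketch under that assumption for arithmetic data: it enumerates every `(x, y)` pair, shuffles with a seeded generator, and keeps the first `train_ratio` fraction for training. The function name `split_pairs` and the exact splitting scheme are illustrative assumptions, not taken from the repository.

```python
import random
from typing import List, Tuple

Pair = Tuple[int, int]


def split_pairs(
    num_elements: int, train_ratio: float = 0.5, seed: int = 42
) -> Tuple[List[Pair], List[Pair]]:
    """Shuffle all (x, y) pairs deterministically and split into train/validation."""
    pairs = [(x, y) for x in range(num_elements) for y in range(num_elements)]
    rng = random.Random(seed)  # local generator, so global random state is untouched
    rng.shuffle(pairs)
    n_train = int(len(pairs) * train_ratio)
    return pairs[:n_train], pairs[n_train:]


train, val = split_pairs(num_elements=97)
```

Using a dedicated `random.Random(seed)` instance means the same seed always yields the same split, independent of any other random calls in the program.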