maxtext
fix dtype bug in adam_pax
[Bug] adam_pax emits a "Some donated buffers were not usable" warning
Reproduced with `weight_dtype=bfloat16`:

```
python3 MaxText/train.py MaxText/configs/base.yml run_name=run steps=10 weight_dtype=bfloat16 opt_type=adam_pax dataset_type=synthetic enable_checkpointing=false
```

```
/home/lizhiyu/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:914:
UserWarning: Some donated buffers were not usable: ShapedArray(bfloat16[512]),
ShapedArray(bfloat16[512,16,7168]), ShapedArray(bfloat16[512,16,7168]),
ShapedArray(bfloat16[7168,16,512]), ShapedArray(bfloat16[512,16]), ...
```
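The underlying mechanism can be sketched outside MaxText. The function `promoting_update` below is hypothetical, not adam_pax code, and the constant `decay` is an assumed stand-in for a decay value computed in float32. It shows that multiplying a bfloat16 state by a strongly typed float32 scalar yields a float32 result, so the jitted function has no output with the same dtype as the donated bfloat16 input. Whether JAX actually prints a donation warning here depends on the backend and JAX version, so the sketch only collects and prints any warnings it sees.

```python
import warnings
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=0)
def promoting_update(m, g):
    # Hypothetical stand-in for a decay computed in float32 (strongly typed scalar).
    decay = jnp.asarray(0.9, dtype=jnp.float32)
    # bfloat16 * float32 promotes to float32, so the new state no longer
    # matches the donated bfloat16 input and cannot reuse its buffer.
    return decay * m + (1.0 - decay) * g


m = jnp.zeros((512, 16), dtype=jnp.bfloat16)
g = jnp.ones((512, 16), dtype=jnp.bfloat16)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    new_m = promoting_update(m, g)

print(new_m.dtype)  # float32
# Any donation warning depends on the backend and JAX version.
for w in caught:
    print(w.category.__name__, w.message)
```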
- root cause: `bias_corrected_decay` forced the optimizer state to be converted to float32, even though the state was initialized in bfloat16. Buffer donation can only reuse an input buffer for an output of the same shape and dtype, so this dtype change broke donation. It wasn't an issue for pax gpt3 since all variables there are float32.
- solution: Added a conversion of the updated optimizer state back to its original dtype, while keeping `bias_corrected_decay` calculated in float32, following the optax source code.
- [x] Verified that the warning disappears after the change.
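A minimal sketch of the fix pattern described above, under stated assumptions: `fixed_update`, `DECAY`, and the `bias_corrected_decay` stand-in are hypothetical and not the actual adam_pax code. The decay is still computed in float32, and only the updated state is cast back to the dtype it was initialized with, so the jitted update's output matches the donated input in shape and dtype.

```python
from functools import partial

import jax
import jax.numpy as jnp

DECAY = 0.9  # hypothetical decay rate


def bias_corrected_decay(step, decay):
    # Hypothetical stand-in: folds bias correction into a step-dependent decay.
    # Always computed in float32, independent of the optimizer-state dtype.
    t = step.astype(jnp.float32) + 1.0
    return decay * (1.0 - decay ** (t - 1.0)) / (1.0 - decay ** t)


@partial(jax.jit, donate_argnums=0)
def fixed_update(m, g, step):
    d = bias_corrected_decay(step, DECAY)  # float32 scalar
    new_m = d * m + (1.0 - d) * g  # arithmetic is promoted to float32
    # Cast back so the output matches the donated input's shape and dtype,
    # which allows the donated buffer to be reused.
    return new_m.astype(m.dtype)


m = jnp.zeros((512, 16), dtype=jnp.bfloat16)
g = jnp.ones((512, 16), dtype=jnp.bfloat16)
new_m = fixed_update(m, g, jnp.asarray(0, dtype=jnp.int32))
print(new_m.dtype)  # bfloat16
```

Keeping the scalar decay in float32 preserves precision for the step-dependent coefficient, while the final cast confines float32 to intermediate arithmetic so the stored state keeps its bfloat16 footprint.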