
fix dtype bug in adam_pax

Open ZhiyuLi-goog opened this issue 11 months ago • 0 comments

[Bug] adam_pax triggers a "donated buffers were not usable" warning

Reproduced with weight_dtype=bfloat16

python3 MaxText/train.py MaxText/configs/base.yml run_name=run steps=10 weight_dtype=bfloat16 opt_type=adam_pax dataset_type=synthetic enable_checkpointing=false
/home/lizhiyu/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:914: 
UserWarning: Some donated buffers were not usable: ShapedArray(bfloat16[512]),
ShapedArray(bfloat16[512,16,7168]), ShapedArray(bfloat16[512,16,7168]), 
ShapedArray(bfloat16[7168,16,512]), ShapedArray(bfloat16[512,16]), ...
  • Root cause: bias_corrected_decay forced the optimizer state to be converted to float32 even though it was initialized in bfloat16. Because the state's dtype changed between the input and the output of the training step, its buffers could no longer be donated. This was not an issue for Pax GPT-3, since all of its variables are float32.

  • Solution: Added a conversion of the updated optimizer state back to its original dtype, while keeping bias_corrected_decay calculated in float32, following the optax source code.

  • [x] Verified the warning disappeared after the change.
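The dtype mismatch described above can be reproduced in a minimal JAX sketch. It assumes a simplified update of a single first-moment buffer; update_buggy, update_fixed, and this bias_corrected_decay helper are hypothetical stand-ins modeled on the description, not the actual adam_pax or optax code. The decay formula used here is one standard way of folding Adam's bias correction into a per-step decay and is an assumption. jax.eval_shape reports each jitted step's output dtype without executing it, so no buffer is actually donated.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch, not MaxText's adam_pax implementation.

def bias_corrected_decay(step, decay):
    # Assumed form of the bias-corrected decay; computed in float32.
    t = step.astype(jnp.float32) + 1.0
    return decay * (1.0 - jnp.power(decay, t - 1.0)) / (1.0 - jnp.power(decay, t))


def update_buggy(mu, grad, step):
    decay = bias_corrected_decay(step, 0.9)
    # bfloat16 combined with a float32 array promotes to float32, so the new
    # state no longer matches the dtype of the donated bfloat16 input.
    return decay * mu + (1.0 - decay) * grad


def update_fixed(mu, grad, step):
    decay = bias_corrected_decay(step, 0.9)  # still float32
    new_mu = decay * mu + (1.0 - decay) * grad  # arithmetic in float32
    # Cast back to the state's original dtype so the output can reuse
    # (alias) the donated input buffer.
    return new_mu.astype(mu.dtype)


# A donated input can only be reused for an output with the same shape
# and dtype.
buggy_step = jax.jit(update_buggy, donate_argnums=0)
fixed_step = jax.jit(update_fixed, donate_argnums=0)

mu = jnp.zeros((512,), dtype=jnp.bfloat16)
grad = jnp.ones((512,), dtype=jnp.bfloat16)
step = jnp.array(0, dtype=jnp.int32)

# eval_shape only traces, so `mu` is not consumed here.
print(jax.eval_shape(buggy_step, mu, grad, step).dtype)  # float32
print(jax.eval_shape(fixed_step, mu, grad, step).dtype)  # bfloat16
```

Doing the arithmetic in float32 and casting only the stored state keeps the precision of the float32 decay while leaving the state's dtype, and therefore buffer donation, unchanged from step to step.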

ZhiyuLi-goog, Mar 09 '24 00:03