[pure bf16 training] w/ `AnyPrecisionAdamW`
Getting pure bf16 training (not mixed precision) running, with `AnyPrecisionAdamW` keeping its optimizer states in bf16 as well.
I think it should require 8 bytes per param, instead of 18 for mixed precision training, i.e. roughly half the memory usage.
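Back-of-the-envelope accounting behind those numbers (a sketch; the exact byte counts depend on which buffers the trainer and optimizer actually keep):

# bytes per parameter, mixed precision (AMP) + regular AdamW
fp32_master_weights = 4
half_weights_copy   = 2
fp32_grads          = 4
fp32_momentum       = 4
fp32_variance       = 4
print(fp32_master_weights + half_weights_copy + fp32_grads + fp32_momentum + fp32_variance)  # 18

# bytes per parameter, pure bf16 + AnyPrecisionAdamW with bf16 states
bf16_weights, bf16_grads, bf16_momentum, bf16_variance = 2, 2, 2, 2
print(bf16_weights + bf16_grads + bf16_momentum + bf16_variance)  # 8
# (+2 more if use_kahan_summation=true keeps a bf16 compensation buffer)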
(I also included a hack to load saved datasets via `load_from_disk`, but it's unrelated to the actual feature and will be removed at the end.)
To test, check out this branch:
git clone https://github.com/huggingface/transformers transformers-bf16
cd transformers-bf16
git checkout full-bf16-train
opt-1.3b / bf16-pure
Then prep an empty opt-1.3b model:
cat << EOT > prep-bf16.py
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
mname = "facebook/opt-1.3b"
config = AutoConfig.from_pretrained(mname)
model = AutoModel.from_config(config, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(mname)
path = "opt-1.3b-bf16"
model.save_pretrained(path)
tokenizer.save_pretrained(path)
EOT
python prep-bf16.py
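Optionally, sanity-check that the saved weights really are bf16 (a quick check, assuming the checkpoint was written as a single pytorch_model.bin):

import torch
# inspect the dtypes of the saved tensors without instantiating the model
sd = torch.load("opt-1.3b-bf16/pytorch_model.bin", map_location="cpu")
print({t.dtype for t in sd.values()})  # expect {torch.bfloat16}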
Train from scratch:
rm -rf save_dir; PYTHONPATH="src" python -m torch.distributed.run \
--nproc_per_node=1 --nnodes=1 --node_rank=0 \
--master_addr=127.0.0.1 --master_port=9901 \
examples/pytorch/language-modeling/run_clm.py --bf16 \
--half_precision_backend no_amp --seed 42 --model_name_or_path opt-1.3b-bf16 \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 --optim \
adamw_anyprecision --optim_args \
'use_kahan_summation=true, momentum_dtype=bfloat16, variance_dtype=bfloat16, compensation_buffer_dtype=bfloat16' \
--per_device_train_batch_size 12 --per_device_eval_batch_size 12 \
--gradient_accumulation_steps 1 --do_train --do_eval --logging_steps 10 \
--save_steps 1000 --eval_steps 100 --weight_decay 0.1 --num_train_epochs 1 \
--adam_beta1 0.9 --adam_beta2 0.95 --learning_rate 0.0002 --lr_scheduler_type \
linear --warmup_steps 500 --report_to tensorboard --output_dir save_dir
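For reference, `--optim adamw_anyprecision` plus those `--optim_args` make the Trainer build the optimizer from `torchdistx`. A minimal sketch of the equivalent hand-constructed optimizer, assuming `torchdistx` is installed and `model` is the bf16 model prepared above:

import torch
from torchdistx.optimizers import AnyPrecisionAdamW

optimizer = AnyPrecisionAdamW(
    model.parameters(),
    lr=2e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    use_kahan_summation=True,                  # compensates for bf16's short mantissa
    momentum_dtype=torch.bfloat16,
    variance_dtype=torch.bfloat16,
    compensation_buffer_dtype=torch.bfloat16,
)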
opt-125m / bf16-pure
If you want to fit into a smaller card, use opt-125m instead.
Then prep an empty opt-125m model:
cat << EOT > prep-bf16.py
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
mname = "facebook/opt-125m"
config = AutoConfig.from_pretrained(mname)
model = AutoModel.from_config(config, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(mname)
path = "opt-125m-bf16"
model.save_pretrained(path)
tokenizer.save_pretrained(path)
EOT
python prep-bf16.py
Train from scratch in pure bf16:
rm -rf save_dir; PYTHONPATH="src" python -m torch.distributed.run \
--nproc_per_node=1 --nnodes=1 --node_rank=0 \
--master_addr=127.0.0.1 --master_port=9901 \
examples/pytorch/language-modeling/run_clm.py --bf16 \
--half_precision_backend no_amp --seed 42 --model_name_or_path opt-125m-bf16 \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 --optim \
adamw_anyprecision --optim_args \
'use_kahan_summation=true, momentum_dtype=bfloat16, variance_dtype=bfloat16, compensation_buffer_dtype=bfloat16' \
--per_device_train_batch_size 12 --per_device_eval_batch_size 12 \
--gradient_accumulation_steps 1 --do_train --do_eval --logging_steps 10 \
--save_steps 1000 --eval_steps 100 --weight_decay 0.1 --num_train_epochs 1 \
--adam_beta1 0.9 --adam_beta2 0.95 --learning_rate 0.0002 --lr_scheduler_type \
linear --warmup_steps 500 --report_to tensorboard --output_dir save_dir
opt-125m / fp16-amp
Same setup with mixed precision fp16 as a baseline (we want pure bf16 to give a similar loss curve when everything else is the same; see the comparison sketch after the command below):
cat << EOT > prep-fp16.py
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
mname = "facebook/opt-125m"
config = AutoConfig.from_pretrained(mname)
model = AutoModel.from_config(config, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(mname)
path = "opt-125m-fp16"
model.save_pretrained(path)
tokenizer.save_pretrained(path)
EOT
python prep-fp16.py
rm -rf save_dir; PYTHONPATH="src" python -m torch.distributed.run \
--nproc_per_node=1 --nnodes=1 --node_rank=0 \
--master_addr=127.0.0.1 --master_port=9901 \
examples/pytorch/language-modeling/run_clm.py --fp16 \
--seed 42 --model_name_or_path opt-125m-fp16 \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
--per_device_train_batch_size 12 --per_device_eval_batch_size 12 \
--gradient_accumulation_steps 1 --do_train --do_eval --logging_steps 10 \
--save_steps 1000 --eval_steps 100 --weight_decay 0.1 --num_train_epochs 1 \
--adam_beta1 0.9 --adam_beta2 0.95 --learning_rate 0.0002 --lr_scheduler_type \
linear --warmup_steps 500 --report_to tensorboard --output_dir save_dir
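To compare the loss curves afterwards, either point tensorboard at both runs' logs or read them from each run's trainer_state.json. A minimal sketch, assuming the two commands are re-run with distinct --output_dir values (the directory names below are placeholders):

import json

for run in ("save_dir_bf16_pure", "save_dir_fp16_amp"):  # hypothetical output dirs
    with open(f"{run}/trainer_state.json") as f:
        state = json.load(f)
    losses = [(e["step"], e["loss"]) for e in state["log_history"] if "loss" in e]
    print(run, losses[-5:])  # last few logged training losses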