fairseq
Unable to run fairseq_cli.eval_lm script for MoE with a sharded checkpoint
🐛 Bug
I trained a MoE model on multiple nodes and am now attempting to calculate eval losses on a new dataset using this model. I am using the example script from the MoE branch to do this. I originally trained the model on 64 GPUs (4 nodes), but now hope to calculate eval losses using only 4 or 16 GPUs on a single node.
My checkpoint directory has files that look like this:
$ ls $MODEL_DIR/checkpoint_last*
checkpoint_last-rank-0-shard0.pt checkpoint_last-rank-27-shard27.pt checkpoint_last-rank-44-shard44.pt checkpoint_last-rank-61-shard61.pt checkpoint_last-shared-shard20.pt checkpoint_last-shared-shard38.pt checkpoint_last-shared-shard55.pt
checkpoint_last-rank-10-shard10.pt checkpoint_last-rank-28-shard28.pt checkpoint_last-rank-45-shard45.pt checkpoint_last-rank-62-shard62.pt checkpoint_last-shared-shard21.pt checkpoint_last-shared-shard39.pt checkpoint_last-shared-shard56.pt
checkpoint_last-rank-11-shard11.pt checkpoint_last-rank-29-shard29.pt checkpoint_last-rank-46-shard46.pt checkpoint_last-rank-63-shard63.pt checkpoint_last-shared-shard22.pt checkpoint_last-shared-shard3.pt checkpoint_last-shared-shard57.pt
checkpoint_last-rank-12-shard12.pt checkpoint_last-rank-2-shard2.pt checkpoint_last-rank-47-shard47.pt checkpoint_last-rank-6-shard6.pt checkpoint_last-shared-shard23.pt checkpoint_last-shared-shard40.pt checkpoint_last-shared-shard58.pt
checkpoint_last-rank-13-shard13.pt checkpoint_last-rank-30-shard30.pt checkpoint_last-rank-48-shard48.pt checkpoint_last-rank-7-shard7.pt checkpoint_last-shared-shard24.pt checkpoint_last-shared-shard41.pt checkpoint_last-shared-shard59.pt
checkpoint_last-rank-14-shard14.pt checkpoint_last-rank-31-shard31.pt checkpoint_last-rank-49-shard49.pt checkpoint_last-rank-8-shard8.pt checkpoint_last-shared-shard25.pt checkpoint_last-shared-shard42.pt checkpoint_last-shared-shard5.pt
checkpoint_last-rank-15-shard15.pt checkpoint_last-rank-32-shard32.pt checkpoint_last-rank-4-shard4.pt checkpoint_last-rank-9-shard9.pt checkpoint_last-shared-shard26.pt checkpoint_last-shared-shard43.pt checkpoint_last-shared-shard60.pt
checkpoint_last-rank-16-shard16.pt checkpoint_last-rank-33-shard33.pt checkpoint_last-rank-50-shard50.pt checkpoint_last-shared-shard0.pt checkpoint_last-shared-shard27.pt checkpoint_last-shared-shard44.pt checkpoint_last-shared-shard61.pt
checkpoint_last-rank-17-shard17.pt checkpoint_last-rank-34-shard34.pt checkpoint_last-rank-51-shard51.pt checkpoint_last-shared-shard10.pt checkpoint_last-shared-shard28.pt checkpoint_last-shared-shard45.pt checkpoint_last-shared-shard62.pt
checkpoint_last-rank-18-shard18.pt checkpoint_last-rank-35-shard35.pt checkpoint_last-rank-52-shard52.pt checkpoint_last-shared-shard11.pt checkpoint_last-shared-shard29.pt checkpoint_last-shared-shard46.pt checkpoint_last-shared-shard63.pt
checkpoint_last-rank-19-shard19.pt checkpoint_last-rank-36-shard36.pt checkpoint_last-rank-53-shard53.pt checkpoint_last-shared-shard12.pt checkpoint_last-shared-shard2.pt checkpoint_last-shared-shard47.pt checkpoint_last-shared-shard6.pt
checkpoint_last-rank-1-shard1.pt checkpoint_last-rank-37-shard37.pt checkpoint_last-rank-54-shard54.pt checkpoint_last-shared-shard13.pt checkpoint_last-shared-shard30.pt checkpoint_last-shared-shard48.pt checkpoint_last-shared-shard7.pt
checkpoint_last-rank-20-shard20.pt checkpoint_last-rank-38-shard38.pt checkpoint_last-rank-55-shard55.pt checkpoint_last-shared-shard14.pt checkpoint_last-shared-shard31.pt checkpoint_last-shared-shard49.pt checkpoint_last-shared-shard8.pt
checkpoint_last-rank-21-shard21.pt checkpoint_last-rank-39-shard39.pt checkpoint_last-rank-56-shard56.pt checkpoint_last-shared-shard15.pt checkpoint_last-shared-shard32.pt checkpoint_last-shared-shard4.pt checkpoint_last-shared-shard9.pt
checkpoint_last-rank-22-shard22.pt checkpoint_last-rank-3-shard3.pt checkpoint_last-rank-57-shard57.pt checkpoint_last-shared-shard16.pt checkpoint_last-shared-shard33.pt checkpoint_last-shared-shard50.pt
checkpoint_last-rank-23-shard23.pt checkpoint_last-rank-40-shard40.pt checkpoint_last-rank-58-shard58.pt checkpoint_last-shared-shard17.pt checkpoint_last-shared-shard34.pt checkpoint_last-shared-shard51.pt
checkpoint_last-rank-24-shard24.pt checkpoint_last-rank-41-shard41.pt checkpoint_last-rank-59-shard59.pt checkpoint_last-shared-shard18.pt checkpoint_last-shared-shard35.pt checkpoint_last-shared-shard52.pt
checkpoint_last-rank-25-shard25.pt checkpoint_last-rank-42-shard42.pt checkpoint_last-rank-5-shard5.pt checkpoint_last-shared-shard19.pt checkpoint_last-shared-shard36.pt checkpoint_last-shared-shard53.pt
checkpoint_last-rank-26-shard26.pt checkpoint_last-rank-43-shard43.pt checkpoint_last-rank-60-shard60.pt checkpoint_last-shared-shard1.pt checkpoint_last-shared-shard37.pt checkpoint_last-shared-shard54.pt
To Reproduce
Steps to reproduce the behavior:
- First, I attempted to run a command that looks like this (basically the same as the example):
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}/home/$USER/fairseq"
NUM_GPUS=`ls /dev/ | grep -E 'nvidia[0-9]+' | wc -l`
CHECKPOINT_TO_PROCESS='checkpoint_last'
TOKENS_PER_SAMPLE=1024
BATCH_SIZE=1
MODEL_CAPACITY=32 # based on train script = 2 * (local_batch_size)/(global_num_experts) = 2 * (8*1024)/512 = 32
MOE_EVAL_CAPACITY_TOKEN_FRACTION=`python3 -c "print($MODEL_CAPACITY/($BATCH_SIZE * $TOKENS_PER_SAMPLE))"`
DATA_PATH=$FINAL_PATH/data/
MODEL_DIR="$FINAL_PATH/fairseq/"
MODEL_PATH="$MODEL_DIR/$CHECKPOINT_TO_PROCESS.pt"
set -ux
python -m fairseq_cli.eval_lm \
$DATA_PATH \
--path $MODEL_PATH \
--gen-subset test_shifted \
--sample-break-mode none \
--tokens-per-sample $TOKENS_PER_SAMPLE \
--batch-size 1 \
--fp16 \
--output-word-probs \
--is-moe \
--distributed-world-size $NUM_GPUS \
--model-overrides "{'world_size': $NUM_GPUS, 'moe_eval_capacity_token_fraction': $MOE_EVAL_CAPACITY_TOKEN_FRACTION}"
This script crashes with the error:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/home/rohitd/.venv/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/gandiva/rohitd/fairseq/fairseq/distributed/utils.py", line 335, in distributed_main
main(cfg, **kwargs)
File "/home/rohitd/fairseq/fairseq_cli/eval_lm.py", line 391, in main
is_moe=is_moe or is_base_moe,
File "/home/rohitd/fairseq/fairseq/checkpoint_utils.py", line 450, in load_model_ensemble_and_task
raise IOError("Model file not found: {}".format(filename))
OSError: Model file not found: /home/rohitd/ALL_FILES/fairseq/checkpoint_last-rank-1.pt
It looks like the script is not picking up the fact that the files have a -shard{number} suffix in their filename. So, I looked through the arguments and found this flag: --use-sharded-state. I tried adding this flag and rerunning the code, but this resulted in the same error:
-- Process 2 terminated with the following error:
Traceback (most recent call last):
File "/home/rohitd/.venv/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/rohitd/fairseq/fairseq/distributed/utils.py", line 335, in distributed_main
main(cfg, **kwargs)
File "/home/rohitd/fairseq/fairseq_cli/eval_lm.py", line 391, in main
is_moe=is_moe or is_base_moe,
File "/home/rohitd/fairseq/fairseq/checkpoint_utils.py", line 450, in load_model_ensemble_and_task
raise IOError("Model file not found: {}".format(filename))
OSError: Model file not found: /home/rohitd/ALL_FILES/fairseq/checkpoint_last-rank-2.pt
Then, I tried making a list of arguments that might potentially fix this. I found:
- --num-shards and --shard-id - but both of these are arguments for the dataloader, so they are unlikely to help checkpoint loading.
- --ddp-backend - setting this to fully_sharded might help - when training I set this flag, so perhaps I need to set it for eval as well.
- --zero-sharding - I don't think I used ZeRO, so I don't see why we should use this.
I tried using just the --ddp-backend fully_sharded argument without setting --use-sharded-state and got the same error:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/rohitd/.venv/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/rohitd/fairseq/fairseq/distributed/utils.py", line 335, in distributed_main
main(cfg, **kwargs)
File "/home/gandiva/rohitd/fairseq/fairseq_cli/eval_lm.py", line 391, in main
is_moe=is_moe or is_base_moe,
File "/home/rohitd/fairseq/fairseq/checkpoint_utils.py", line 450, in load_model_ensemble_and_task
raise IOError("Model file not found: {}".format(filename))
OSError: Model file not found: /home/rohitd/ALL_FILES/fairseq/checkpoint_last-rank-0.pt
Then I tried setting both --ddp-backend fully_sharded and --use-sharded-state together, and it's still the same:
-- Process 3 terminated with the following error:
Traceback (most recent call last):
File "/home/rohitd/.venv/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/rohitd/fairseq/fairseq/distributed/utils.py", line 335, in distributed_main
main(cfg, **kwargs)
File "/home/rohitd/fairseq/fairseq_cli/eval_lm.py", line 391, in main
is_moe=is_moe or is_base_moe,
File "/home/rohitd/fairseq/fairseq/checkpoint_utils.py", line 450, in load_model_ensemble_and_task
raise IOError("Model file not found: {}".format(filename))
OSError: Model file not found: /home/ALL_FILES/fairseq/checkpoint_last-rank-3.pt
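All four attempts fail while looking for the same pattern, checkpoint_last-rank-{i}.pt, whereas the files on disk are named checkpoint_last-rank-{i}-shard{i}.pt. A minimal sketch of that mismatch, using only the names visible in the tracebacks and the directory listing above (the helper names are hypothetical and this is not fairseq code):

```python
import re

def expected_rank_file(prefix: str, rank: int) -> str:
    """Name that the tracebacks show eval_lm looking for on a given rank."""
    return f"{prefix}-rank-{rank}.pt"

def strip_shard_suffix(filename: str) -> str:
    """Drop the -shard{i} suffix that the training run appended."""
    return re.sub(r"-shard[0-9]+(?=\.pt$)", "", filename)

# File names as they appear in the checkpoint directory (64 training ranks).
on_disk = [f"checkpoint_last-rank-{i}-shard{i}.pt" for i in range(64)]
# File names requested when evaluating on 4 GPUs.
wanted = [expected_rank_file("checkpoint_last", r) for r in range(4)]

# None of the wanted names exist as-is ...
missing = [name for name in wanted if name not in on_disk]
# ... but each one matches an on-disk file once the suffix is removed.
renamed = {strip_shard_suffix(name) for name in on_disk}
print(missing)
print(all(name in renamed for name in wanted))
```

This suggests the problem is purely one of file naming, not of missing data.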
Expected behavior
The fairseq_cli.eval_lm script loads the MoE model and prints the eval loss.
Environment
- fairseq Version: git+https://github.com/pytorch/fairseq@ebea0072062e2e8f4563644f27546df355357f5e#egg=fairseq
- PyTorch Version: 1.12.0
- OS: Ubuntu 18.04.6 LTS
- How you installed fairseq (pip, source): git clone followed by git checkout moe followed by pip install --editable .
- Build command you used (if compiling from source):
- Python version: 3.7.5
- CUDA/cuDNN version: cuda_11.6.r11.6/compiler.30978841_0
- GPU models and configuration: 4 x NVIDIA V100
You could circumvent this by running eval through the train loop itself: specify the checkpoint and pass your new dataset as a command-line argument to train.py. This works, but it requires as many GPUs as were used during training, which in my case was 64. It would be nice if there was a way to use the fairseq_cli.eval_lm script instead, so that eval can run on one node.
Another issue I noticed with using train.py is that calculating eval PPL is extremely slow. It took 3 hours on 64 GPUs to calculate eval loss for 220M tokens. In contrast, one train batch (524288 tokens) takes approximately two seconds, which means that training on 220M tokens would take about 14 minutes, so eval is roughly 13x slower. This is what the logs look like:
2022-10-18 07:52:22 | INFO | fairseq.modules.fused_bias_gelu | Done with compiling and loading fused kernels.
2022-10-18 07:56:38 | INFO | train_inner | {"epoch": 2, "update": 1.367, "loss": "3.734", "moe_gate_loss": "10.9212", "overflow_expert1": "1.683", "overflow_expert2": "48.656", "entropy_gating": "5.45", "expert1_balance_top": "34.515", "expert1_balance_bottom": "9.431", "unused_expert1_count": "0.496", "expert2_balance_top": "65.803", "expert2_balance_bottom": "1.361", "unused_expert2_count": "42.408", "all_to_all_cpu_time_ms": "0", "all_to_all_cuda_time_ms": "0", "inner_loss": "3.576", "ppl": "11.93", "wps": "228.1", "ups": "0", "wpb": "524288", "bsz": "512", "num_updates": "572001", "lr": "1.09659e-07", "gnorm": "0.179", "loss_scale": "8", "train_wall": "1026", "cuda_gb_allocated": "9.5", "cuda_gb_reserved": "12.1", "cuda_gb_free": "22.3", "wall": "0"}
2022-10-18 07:56:38 | INFO | fairseq_cli.train | begin validation on "test_shifted" subset on rank 0
2022-10-18 07:56:38 | INFO | fairseq_cli.train | got valid iterator on "test_shifted" subset on rank 0
2022-10-18 07:56:38 | INFO | fairseq_cli.train | Begin looping over validation "test_shifted" subset with length "419"
2022-10-18 07:56:38 | INFO | fairseq_cli.train | Inside the handler
2022-10-18 10:55:16 | INFO | test_shifted | {"epoch": 2, "test_shifted_loss": "3.699", "test_shifted_moe_gate_loss": "10.1184", "test_shifted_overflow_expert1": "0", "test_shifted_overflow_expert2": "0", "test_shifted_entropy_gating": "5.451", "test_shifted_expert1_balance_top": "37.697", "test_shifted_expert1_balance_bottom": "7.838", "test_shifted_unused_expert1_count": "0.894", "test_shifted_expert2_balance_top": "64.978", "test_shifted_expert2_balance_bottom": "1.549", "test_shifted_unused_expert2_count": "36.374", "test_shifted_all_to_all_cpu_time_ms": "0", "test_shifted_all_to_all_cuda_time_ms": "0", "test_shifted_inner_loss": "3.553", "test_shifted_ppl": "11.73", "test_shifted_wps": "20451.1", "test_shifted_wpb": "523071", "test_shifted_bsz": "510.8", "test_shifted_num_updates": "572001"}
2022-10-18 10:55:18 | INFO | train_inner | {"epoch": 2, "update": 1.367, "loss": "3.823", "moe_gate_loss": "10.8059", "overflow_expert1": "1.684", "overflow_expert2": "48.231", "entropy_gating": "5.453", "expert1_balance_top": "34.583", "expert1_balance_bottom": "9.416", "unused_expert1_count": "0.497", "expert2_balance_top": "65.621", "expert2_balance_bottom": "1.374", "unused_expert2_count": "42.297", "all_to_all_cpu_time_ms": "0", "all_to_all_cuda_time_ms": "0", "inner_loss": "3.667", "ppl": "12.7", "wps": "48.9", "ups": "0", "wpb": "524288", "bsz": "512", "num_updates": "572002", "lr": "1.09121e-07", "gnorm": "0.176", "loss_scale": "8", "train_wall": "2", "cuda_gb_allocated": "17.6", "cuda_gb_reserved": "28.7", "cuda_gb_free": "14.1", "wall": "0"}
and this is the script I used:
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}/home/t-rohitd/fairseq"
EVAL_AT_STEP_NUMBER=checkpoint_2_572000.pt
NODES=4
GPUS_PER_NODE=16
TOTAL_GPUS=$((NODES * GPUS_PER_NODE)) # 64; referenced by --checkpoint-shard-count below
NUM_EXPERTS=512
TOKENS_PER_SAMPLE=1024
BATCH_SIZE=8 # batch size per GPU
GRAD_ACC=1 # gradient accumulation
# launch the job (adjust port and --cpu-bind if needed)
DISTRIBUTED_PORT=12345
srun -o ${LOCAL_CKPT_DIR}/eval_log_${EVAL_AT_STEP_NUMBER}.txt --gpus-per-node ${GPUS_PER_NODE} --ntasks-per-node ${GPUS_PER_NODE} --cpus-per-task 6 --nodes $NODES --mem-per-gpu 80G \
python fairseq_cli/train.py \
--train-subset train_shifted --valid-subset test_shifted \
--distributed-port ${DISTRIBUTED_PORT} \
--save-dir ${LOCAL_CKPT_DIR} --save-interval-updates 200 --save-async \
--load-checkpoint-on-all-dp-ranks --checkpoint-shard-count ${TOTAL_GPUS} \
--ddp-backend fully_sharded --memory-efficient-fp16 --checkpoint-activations \
--task language_modeling ${SRC_DIR} --tokens-per-sample ${TOKENS_PER_SAMPLE} \
--arch transformer_lm_gpt2_small --share-decoder-input-output-embed \
--decoder-layers 24 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 \
--decoder-attention-heads 16 \
--moe-expert-count ${NUM_EXPERTS} --moe-freq 2 \
--moe-gating-use-fp32 --moe-second-expert-policy all \
--moe-normalize-expert-grad sqrt_world_size \
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr ${LR} --lr-scheduler linear --warmup-updates ${WARMUP_UPDATES} --total-num-update ${TOTAL_UPDATES} \
--dropout 0.1 --attention-dropout 0.1 \
--batch-size ${BATCH_SIZE} --update-freq ${GRAD_ACC} --required-batch-size-multiple ${REQUIRED_BATCH_SIZE_MULTIPLE} \
--max-update ${TOTAL_UPDATES} --log-format json --log-interval 1 --restore-file ${REMOTE_CKPT_DIR}/${EVAL_AT_STEP_NUMBER} \
--validate-interval-updates 100 --num-workers-valid 1 --max-valid-steps 418
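As a rough check of the slowdown quoted above, the figures from the log (a wpb of 524288 tokens per update, about two seconds per training update, and the two validation timestamps) can be plugged into a few lines of Python. This is only back-of-the-envelope arithmetic, and the 220M token count is the approximate figure from the text:

```python
from datetime import datetime

TOKENS = 220e6               # approximate eval set size quoted above
TOKENS_PER_UPDATE = 524288   # "wpb" in the train_inner log lines
SECONDS_PER_UPDATE = 2       # "train_wall" of the second train_inner line

# Time that training on the same number of tokens would take.
updates = TOKENS / TOKENS_PER_UPDATE
train_minutes = updates * SECONDS_PER_UPDATE / 60

# Wall-clock time of the validation pass, taken from the log timestamps.
fmt = "%Y-%m-%d %H:%M:%S"
start = datetime.strptime("2022-10-18 07:56:38", fmt)
end = datetime.strptime("2022-10-18 10:55:16", fmt)
eval_minutes = (end - start).total_seconds() / 60

print(f"train: {train_minutes:.0f} min, eval: {eval_minutes:.0f} min, "
      f"ratio: {eval_minutes / train_minutes:.1f}x")
```

Note that a training update includes a backward pass while validation does not, so the gap per unit of compute is likely larger than this ratio suggests.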
Managed to solve this issue - in case anyone else stumbles on this problem in the future, just thought I'd post what I did here. Let us say we want to use checkpoint_last for calculating eval PPLs. There are a few things that need to be changed:
- In my checkpoint, when training with a world_size of 64, there were 64 checkpoint files in the format checkpoint_last-shared-shard{i}.pt where i ranged from 0 to 63. However, all of them contain the same information. The eval_lm script expects exactly one file with the name checkpoint_last-shared.pt (with no shard{i} in the filename), so one of them (I used shard 0) has to be made available under that name.
- Secondly, rename all files checkpoint_last-rank-{i}-shard{i}.pt to remove the -shard{i} part from the name.
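The two renames above can be sketched in Python as follows. This is an illustrative sketch, not part of fairseq: the function name link_unsharded_names and its arguments are hypothetical, and it assumes, as described above, that every -shared-shard{i}.pt file holds the same contents so that shard 0 can stand in for all of them.

```python
import re
import tempfile
from pathlib import Path

def link_unsharded_names(ckpt_dir: str, prefix: str = "checkpoint_last") -> Path:
    """Create a temp dir of symlinks whose names have no -shard{i} suffix."""
    src = Path(ckpt_dir).resolve()  # absolute targets keep the links valid
    tmp = Path(tempfile.mkdtemp())
    # Expert files: one link per rank, e.g. checkpoint_last-rank-3.pt.
    for path in src.glob(f"{prefix}-rank-*-shard*.pt"):
        link_name = re.sub(r"-shard[0-9]+(?=\.pt$)", "", path.name)
        (tmp / link_name).symlink_to(path)
    # Shared weights: a single link that points at shard 0.
    shared = src / f"{prefix}-shared-shard0.pt"
    (tmp / f"{prefix}-shared.pt").symlink_to(shared)
    return tmp
```

The returned directory plays the role of TEMP_FOLDER in the shell script, and the original checkpoint files are left untouched.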
I did this by creating a new temp directory and creating symlinks. Running the following shell script worked perfectly:
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}/home/t-rohitd/fairseq"
NUM_GPUS=`ls /dev/ | grep -E 'nvidia[0-9]+' | wc -l`
CHECKPOINT_TO_PROCESS='checkpoint_last'
TOKENS_PER_SAMPLE=1024
BATCH_SIZE=1
MODEL_CAPACITY=32 # based on train script = 2 * (local_batch_size)/(global_num_experts) = 2 * (8*1024)/512 = 32
MOE_EVAL_CAPACITY_TOKEN_FRACTION=`python3 -c "print($MODEL_CAPACITY/($BATCH_SIZE * $TOKENS_PER_SAMPLE))"`
DATA_PATH=path/to/my/dataset
# create temporary model checkpoint directory and create symlinks
# the checkpoint dir must be an absolute path, because the symlinks are created from inside $TEMP_FOLDER
# the -name pattern is quoted so that the shell does not expand the glob before find sees it
RANK_PATHS=`find /path/to/ckpt/dir/ -name "$CHECKPOINT_TO_PROCESS-rank-*.pt"`
TEMP_FOLDER=`mktemp -d`
pushd $TEMP_FOLDER
for m in $RANK_PATHS;
do
filename=`basename $m | sed 's/-shard[0-9]*//g'` # keep only the filename and strip the -shard{i} suffix
ln -s $m ./$filename
done;
# all shared shards hold the same contents, so linking shard 0 is enough
SHARED_PATH=`find /path/to/ckpt/dir/ -name "$CHECKPOINT_TO_PROCESS-shared-shard0.pt"`
filename=`basename $SHARED_PATH | sed 's/-shard[0-9]*//g'`
ln -s $SHARED_PATH ./$filename
popd
set -ux
python -m fairseq_cli.eval_lm \
$DATA_PATH \
--ddp-backend fully_sharded \
--path $TEMP_FOLDER/$CHECKPOINT_TO_PROCESS.pt \
--gen-subset test_shifted \
--sample-break-mode none \
--tokens-per-sample $TOKENS_PER_SAMPLE \
--batch-size $BATCH_SIZE \
--fp16 --is-moe --distributed-world-size $NUM_GPUS \
--model-overrides "{'world_size': $NUM_GPUS, 'moe_eval_capacity_token_fraction': $MOE_EVAL_CAPACITY_TOKEN_FRACTION}" \
--log-format json
Running this command requires a lot of RAM: in my case, peak memory utilisation was 500+ GB, so this process will OOM on nodes with limited RAM. The output of this command looks something like this:
2022-11-16 07:05:19 | INFO | fairseq.checkpoint_utils | load_model_ensemble_and_task is_moe=True
2022-11-16 07:05:19 | INFO | fairseq.moe_checkpoint_utils | Found total 64 expert files and current distributed world size: 16, Stitching experts to able to load on current world size.
.
.
.
2022-11-16 07:47:42 | INFO | fairseq_cli.eval_lm | Evaluated <number> tokens in <time>s (<float> tokens/s)
2022-11-16 07:47:42 | INFO | fairseq_cli.eval_lm | test_shifted Loss (base 2): <float>, Perplexity: <float>
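For reference, the capacity arithmetic used in both eval scripts can be written out explicitly. The sketch below only reproduces the formula given in the script comment (capacity = 2 * local tokens per batch / number of experts) and the inline python3 one-liner; the function names are hypothetical and nothing here calls fairseq.

```python
def model_capacity(train_batch_size: int, tokens_per_sample: int, num_experts: int) -> float:
    """Per-expert capacity, following the comment in the eval script."""
    return 2 * (train_batch_size * tokens_per_sample) / num_experts

def eval_capacity_token_fraction(capacity: float, eval_batch_size: int, tokens_per_sample: int) -> float:
    """Capacity expressed as a fraction of the tokens in one eval batch."""
    return capacity / (eval_batch_size * tokens_per_sample)

# Values from the training script: batch size 8 per GPU, 1024 tokens, 512 experts.
capacity = model_capacity(train_batch_size=8, tokens_per_sample=1024, num_experts=512)
# Values from the eval script: batch size 1, 1024 tokens per sample.
fraction = eval_capacity_token_fraction(capacity, eval_batch_size=1, tokens_per_sample=1024)
print(capacity, fraction)  # 32.0 0.03125
```

If the eval batch size or tokens per sample change, the fraction passed in --model-overrides has to be recomputed so that the capacity matches the value used in training.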