[Question] Very low MFU (30%~35%) when training Llama2 and GPT models in bf16 on a single SXM4 A100 machine.
I don't know what is going wrong. Are the compute precision and parameter precision not set correctly? DeepSpeed or Megatron can easily reach 55% MFU on the same machine. Here is my bash script:
#! /bin/bash
set -u
set -o pipefail
TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"} # Precision (float32, bfloat16)
NUM_GPUS=${4:-8} # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
LOG_DIR=${6:-"test_logdir"}
export VOCAB_PATH=$VOCAB_PATH
BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
--xla_gpu_simplify_all_fp_conversions --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=51200
--xla_gpu_graph_level=3 --xla_gpu_enable_async_all_reduce=true
--xla_gpu_enable_async_collectives=true --xla_gpu_enable_async_collective_permute=true
--xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_async_all_to_all=true --xla_gpu_all_reduce_contiguous=true
--xla_gpu_all_reduce_blueconnect_num_devices_per_host=true
--xla_gpu_enable_cudnn_frontend=true --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true
--xla_gpu_enable_cudnn_layer_norm "}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"
export ENABLE_TE=1
mkdir -p ${LOG_DIR}
python3 -u -m paxml.main \
--job_log_dir=${LOG_DIR} \
--fdl_config=paxml.tasks.lm.params.nvidia.Llama2_7B \
--fdl.FPROP_DTYPE=\"${PREC}\" \
--fdl.ICI_MESH_SHAPE="[1,${NUM_GPUS},1]" \
--fdl.DCN_MESH_SHAPE="[1,1,1]" \
--fdl.NUM_STAGES=1 \
--fdl.MICROBATCH_SIZE=$PERCORE_BATCH_SIZE \
--fdl.PERCORE_BATCH_SIZE=$PERCORE_BATCH_SIZE \
--tfds_data_dir=$TFDS_DATA_DIR \
--alsologtostderr \
2>&1 | tee ${LOG_DIR}/llama2_7B_output.log
EXP_STATUS=$?
if [ ${EXP_STATUS} -ne 0 ]; then
  echo "Run failed"
else
  echo "Run succeeded!"
fi
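For reference, with the defaults above (NUM_GPUS=8, PERCORE_BATCH_SIZE=4) each step processes the amount of data sketched below; the 4096-token context length is my assumption about the Llama2_7B config:

num_gpus = 8
percore_batch_size = 4
seq_len = 4096                                   # assumed Llama2-7B context length
global_batch = num_gpus * percore_batch_size     # 32 sequences per step
tokens_per_step = global_batch * seq_len         # 131072 tokens per step
print(global_batch, tokens_per_step)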
According to https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax, NVIDIA trains a 5B GPT model in native BF16 on 256 A100 GPUs and reports 465.45 sequences/sec at a global batch size of 8*256=2048 sequences, i.e. about 4.4 s per step. Am I correct? The script below computes the corresponding MFU as 38.958427%, which also seems too low.
# NVIDIA JAX GPT-5B run from the link above
card_num = 256
gbs = 8 * card_num                        # global batch size in sequences
layers = 24
num_query = 32
num_heads = 32
enc_seq_len = 2048
hs = 4096                                 # hidden size
ffn_hs = 16384                            # feed-forward hidden size
vocab = 50304
sequences_per_sec = 465.45
seconds_per_step = gbs / sequences_per_sec
# Model total parameters:
params_qkv_state = (1 + 2 * (num_query / num_heads)) * hs * hs
params_post_attention_linear = hs * hs
params_feed_forward_network = 2 * hs * ffn_hs
params_vocabulary_embedding = hs * vocab
# FPROP FLOPs per step (2 FLOPs per multiply-accumulate):
qkv_state = gbs * 2 * (1 + 2 * (num_query / num_heads)) * enc_seq_len * hs * hs
attention_matrix_computation = gbs * 2 * enc_seq_len * enc_seq_len * hs
attention_over_values = gbs * 2 * enc_seq_len * enc_seq_len * hs
post_attention_linear_projection = gbs * 2 * enc_seq_len * hs * hs
feed_forward_network = gbs * (2 * 2 * enc_seq_len * ffn_hs * hs)
vocabulary_embedding = gbs * 2 * enc_seq_len * hs * vocab
# BPROP is counted as 2x FPROP, so the total per step is 3x FPROP.
model_params = (params_qkv_state + params_post_attention_linear + params_feed_forward_network) * layers + params_vocabulary_embedding
model_float = 3 * ((qkv_state + attention_matrix_computation + attention_over_values + post_attention_linear_projection + feed_forward_network) * layers + vocabulary_embedding)
model_flops = model_float / seconds_per_step
cluster_ideal_flops = 312 * (10**12) * card_num   # A100 BF16 dense peak: 312 TFLOPS per GPU
MFU = model_flops / cluster_ideal_flops
print("Model parameters {:4f}B MFU={:4f}%".format(model_params / (10**9), MFU * 100))