
[Question] Very low MFU (30%~35%) when training bf16 Llama2 and GPT models on a single SXM4 A100 machine.

MoFHeka opened this issue 5 months ago

I don't know what went wrong; are the compute precision and parameter precision not set correctly? DeepSpeed or Megatron can easily reach 55% MFU on the same machine. Here is my bash script:

#! /bin/bash
set -u
set -o pipefail

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}      # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
LOG_DIR=${6:-"test_logdir"}

export VOCAB_PATH=$VOCAB_PATH

BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
                       --xla_gpu_simplify_all_fp_conversions --xla_gpu_enable_async_all_gather=true
                       --xla_gpu_enable_async_reduce_scatter=true  --xla_gpu_enable_highest_priority_async_stream=true
                       --xla_gpu_enable_triton_softmax_fusion=false  --xla_gpu_all_reduce_combine_threshold_bytes=51200
                       --xla_gpu_graph_level=3 --xla_gpu_enable_async_all_reduce=true
                       --xla_gpu_enable_async_collectives=true --xla_gpu_enable_async_collective_permute=true
                       --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
                       --xla_gpu_enable_async_all_to_all=true --xla_gpu_all_reduce_contiguous=true
                       --xla_gpu_all_reduce_blueconnect_num_devices_per_host=true
                       --xla_gpu_enable_cudnn_frontend=true --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true
                       --xla_gpu_enable_cudnn_layer_norm "}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"

export ENABLE_TE=1

mkdir -p ${LOG_DIR}
python3 -u -m paxml.main \
    --job_log_dir=${LOG_DIR} \
    --fdl_config=paxml.tasks.lm.params.nvidia.Llama2_7B \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[1,${NUM_GPUS},1]" \
    --fdl.DCN_MESH_SHAPE="[1,1,1]" \
    --fdl.NUM_STAGES=1 \
    --fdl.MICROBATCH_SIZE=$PERCORE_BATCH_SIZE \
    --fdl.PERCORE_BATCH_SIZE=$PERCORE_BATCH_SIZE \
    --tfds_data_dir=$TFDS_DATA_DIR \
    --alsologtostderr \
    2>&1 | tee ${LOG_DIR}/llama2_7B_output.log

EXP_STATUS=$?

if [ $EXP_STATUS -ne 0 ]; then
  echo "Run failed"
else
  echo "Run succeeded!"
fi
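
As a sanity check on the precision question above, here is a minimal JAX sketch (toy shapes only, not the real Pax train state; the names are just illustrative) showing one way to confirm which dtypes the parameters actually end up in:

import jax
import jax.numpy as jnp

# Toy parameter pytree standing in for the real Pax train state; with bf16
# fprop and fp32 master weights you would expect a mix of these two dtypes.
params = {
    "embedding": jnp.zeros((8, 4), dtype=jnp.float32),
    "layer_0": {"qkv": jnp.zeros((4, 12), dtype=jnp.bfloat16)},
}

# Print the dtype of every pytree leaf.
print(jax.tree_util.tree_map(lambda x: x.dtype, params))
# e.g. {'embedding': dtype('float32'), 'layer_0': {'qkv': dtype(bfloat16)}}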

According to https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax, NVIDIA trains a 5B GPT model in native BF16 on 256 A100 GPUs at 465.45 sequences/sec with a global batch size of 8*256=2048 sequences, which works out to about 4.4 s per step. Am I correct? The script below computes the corresponding MFU, which comes out to 38.958427%. That's too low!

# NVIDIA JAX GPT-5B
card_num=256
gbs=8*card_num
layers=24
num_query=32
num_heads=32
enc_seq_len=2048
hs=4096
ffn_hs=16384
vocab=50304

sequences_per_sec=465.45
seconds_per_step=gbs/sequences_per_sec


#Model total parameters:
params_qkv_state = (1+2*(num_query/num_heads))*hs*hs
params_post_attention_linear = hs*hs
params_feed_forward_network = 2*hs*ffn_hs
params_vocabulary_embedding = hs*vocab


#FPROP:
qkv_state = gbs*2*(1+2*(num_query/num_heads))*enc_seq_len*hs*hs
attention_matrix_computation = gbs*2*enc_seq_len*enc_seq_len*hs
attention_over_values = gbs*2*enc_seq_len*enc_seq_len*hs
post_attention_linear_projection = gbs*2*enc_seq_len*hs*hs
feed_forward_network = gbs*(2*2*enc_seq_len*ffn_hs*hs)
vocabulary_embedding = gbs*2*enc_seq_len*hs*vocab

#BPROP:
# roughly 2x the FPROP FLOPs, so total compute per step = 3 * FPROP

model_params = (params_qkv_state+params_post_attention_linear+params_feed_forward_network)*layers + params_vocabulary_embedding
model_float = 3*((qkv_state+attention_matrix_computation+attention_over_values+post_attention_linear_projection+feed_forward_network)*layers + vocabulary_embedding)
model_flops = model_float/seconds_per_step
cluster_ideal_flops = 312*(10**12) * card_num  # 312 TFLOP/s = A100 peak dense BF16 per GPU
MFU = model_flops/cluster_ideal_flops
print("Model parameters {:4f}B MFU={:4f}%".format(model_params/(10**9),MFU*100))

MoFHeka, Jan 26 '24