Converted MLPerf GPT-3 checkpoint starts with a worse loss

Open gramesh-amd opened this issue 5 months ago • 26 comments

Hello, we converted the PaxML checkpoint and resumed training with the following config:

base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "tfds"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
tokenize_eval_data: False
eval_data_column: "ids"
add_bos: False
add_eos: False
eval_split: "validation_tokenized_5662seqs"
eval_interval: 10  # number of train steps between eval steps
target_eval_loss: 2.69  # early stop once reaching target eval_loss

enable_checkpointing: True
save_interval_steps: 5

# Args coming from the NVIDIA spreadsheet http://shortn/_W9CzVbtQde and
# third_party/py/maxtext/configs/a3/llama_2_7b.
hardware: "gpu"
steps: 10
model_name: "gpt3-175b" # this model config is unchanged
attention: "cudnn_flash_te"

gradient_accumulation_steps: 1

dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_pipeline_parallelism: 1
dcn_tensor_parallelism: 1
dcn_sequence_parallelism: 1
ici_fsdp_parallelism: 8
ici_data_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
per_device_batch_size: 5
max_target_length: 2048

remat_policy: "full"
use_iota_embed: True
scan_layers: False
async_checkpointing: False
logits_dot_in_fp32: False
megablox: False

dtype: "bfloat16"
quantization: ""
quantize_kvcache: False
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
weight_dtype: bfloat16
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint

skip_first_n_steps_for_profiler: 3

mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
logical_axis_rules: [
                      ['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                       # For pipeline parallelism the pre and post decoder layer tensors' batch dimension is sharded by stages.
                       # Microbatches are sharded by stage, so moving out of and into this sharding should be a local reshape.
                       # The "stage" needs to be listed first since the microbatch dimension is first before the reshape.
                      ['activation_embed_and_logits_batch', ['stage', 'data', 'fsdp', 'fsdp_transpose']],
                      ['activation_heads', ['tensor','sequence']],
                      ['activation_kv_heads', ['tensor','sequence']],
                      ['activation_length', 'sequence'],
                      ['activation_embed', 'tensor'],
                      ['activation_mlp', 'tensor'],
                      ['activation_kv', 'tensor'],
                      ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose',]],
                      ['activation_kv_head_dim', 'tensor'],
                      ['activation_vocab', ['tensor', 'sequence']],
                      ['activation_vocab', 'tensor'],
                      ['activation_vocab', 'sequence'],
                      ['activation_stage','stage'],
                      ['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
                      ['vocab', ['tensor', 'autoregressive']],
                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
                      ['embed', ['fsdp', 'sequence']],
                      ['norm', 'fsdp'],
                      ['heads', ['tensor', 'autoregressive']],
                      ['layers', 'stage'],
                      ['kv', []],
                      ['kv_heads', ['tensor', 'autoregressive']],
                      ['kv_head_dim', []],
                      ['cache_batch', []],
                      ['cache_heads', ['autoregressive', 'tensor']],
                      ['cache_kv', []],
                      ['cache_sequence', []],
                    ]

# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]
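For reference, the batch geometry implied by this config can be worked out by hand. The following is a minimal sketch under three assumptions. First, `dcn_fsdp_parallelism: -1` means "fill this axis with whatever factor is left over" across hosts. Second, each host has 8 GPUs forming the ICI group. Third, the job runs on a hypothetical 4 hosts; the issue does not state the device count. `resolve_fill_axis` and `summarize` are illustrative helpers, not MaxText functions.

```python
import math


def resolve_fill_axis(total: int, axes: dict[str, int]) -> dict[str, int]:
    """Hypothetical helper: replace a single -1 entry so the product of all
    axis sizes equals `total` (the number of devices those axes span)."""
    unknown = [name for name, size in axes.items() if size == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(size for size in axes.values() if size != -1)
    if total % known != 0:
        raise ValueError(f"{total} devices cannot be split into groups of {known}")
    resolved = dict(axes)
    if unknown:
        resolved[unknown[0]] = total // known
    elif known != total:
        raise ValueError(f"axis sizes multiply to {known}, expected {total}")
    return resolved


# Axis sizes copied from the config above.
ICI = {"data": 1, "fsdp": 8, "sequence": 1, "tensor": 1, "pipeline": 1}
DCN = {"data": 1, "fsdp": -1, "sequence": 1, "tensor": 1, "pipeline": 1}
PER_DEVICE_BATCH_SIZE = 5
MAX_TARGET_LENGTH = 2048


def summarize(num_hosts: int, gpus_per_host: int = 8) -> dict[str, int]:
    """Assumes ICI axes span the GPUs of one host and DCN axes span hosts."""
    ici = resolve_fill_axis(gpus_per_host, ICI)
    dcn = resolve_fill_axis(num_hosts, DCN)
    global_batch = PER_DEVICE_BATCH_SIZE * num_hosts * gpus_per_host
    return {
        "fsdp_total": ici["fsdp"] * dcn["fsdp"],
        "global_batch_sequences": global_batch,
        "tokens_per_step": global_batch * MAX_TARGET_LENGTH,
    }


print(summarize(num_hosts=4))
# -> {'fsdp_total': 32, 'global_batch_sequences': 160, 'tokens_per_step': 327680}
```

Only one axis per group may be -1, which is why the helper rejects configs with two.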

The tokenizer and data splits (3.0.4 and 3.0.5) were downloaded from the mlperf2 bucket. I have also tried setting dataset_type to c4_mlperf, like this:

base_config: "base.yml"
tokenizer_path: "/dockerx/vocab/c4_en_301_5Mexp2_spm.model"
dataset_type: "c4_mlperf"
dataset_path: "/ckpts/c4_mlperf_dataset"
dataset_name: "en:3.0.4"
eval_dataset_name: "en:3.0.5"
split: "train2"
eval_split: "validation_tokenized_5662seqs"

We launched training with:

python maxtext/MaxText/train.py /dockerx/maxtext/MaxText/configs/gpt3_175b_gpu.yml base_output_directory=/ckpts/paxml/gpt3-conversion run_name=gpt3-conversion steps=4010 scan_layers=true

^ scan_layers is set to true to match how we converted the checkpoint.

completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
To see full metrics 'tensorboard --logdir=/ckpts/paxml/gpt3-conversion/gpt3-conversion/tensorboard/'
completed step: 4001, seconds: 12.945, TFLOP/s/device: 170.297, Tokens/s/device: 158.213, total_weights: 65504, loss: 7.687, perplexity: 2179.917
completed step: 4002, seconds: 11.886, TFLOP/s/device: 185.471, Tokens/s/device: 172.310, total_weights: 65504, loss: 7.739, perplexity: 2297.215
completed step: 4003, seconds: 11.885, TFLOP/s/device: 185.479, Tokens/s/device: 172.318, total_weights: 65504, loss: 7.597, perplexity: 1992.680
completed step: 4004, seconds: 11.931, TFLOP/s/device: 184.759, Tokens/s/device: 171.649, total_weights: 65504, loss: 7.680, perplexity: 2165.097
completed step: 4005, seconds: 11.913, TFLOP/s/device: 185.043, Tokens/s/device: 171.912, total_weights: 65504, loss: 7.663, perplexity: 2128.778
completed step: 4006, seconds: 11.945, TFLOP/s/device: 184.546, Tokens/s/device: 171.451, total_weights: 65504, loss: 7.582, perplexity: 1963.248
completed step: 4007, seconds: 11.913, TFLOP/s/device: 185.048, Tokens/s/device: 171.918, total_weights: 65504, loss: 7.648, perplexity: 2096.574
completed step: 4008, seconds: 12.013, TFLOP/s/device: 183.498, Tokens/s/device: 170.478, total_weights: 65504, loss: 7.524, perplexity: 1851.645
completed step: 4009, seconds: 11.920, TFLOP/s/device: 184.929, Tokens/s/device: 171.807, total_weights: 65504, loss: 7.618, perplexity: 2034.629

^ Training starts with a very high loss (about 7.6), but we expected something closer to 2.77.
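The loss and perplexity columns can be cross-checked with a few lines of Python. The following is a minimal sketch that does three things. It parses the "completed step" lines (only three are copied here) and checks that perplexity ≈ exp(loss). It then converts the 2.77 target loss into a perplexity of about 16, for comparison with the roughly 2000 seen above. It also multiplies seconds by Tokens/s/device. Assuming Tokens/s/device is per-device tokens divided by step time, that product gives the tokens processed per device per step. This figure can be compared with per_device_batch_size × max_target_length in the config actually used for the run.

```python
import math
import re

# Three lines copied from the log above; paste the full log to check every step.
LOG = """\
completed step: 4000, seconds: 91.772, TFLOP/s/device: 24.021, Tokens/s/device: 22.316, total_weights: 65504, loss: 7.644, perplexity: 2088.295
completed step: 4001, seconds: 12.945, TFLOP/s/device: 170.297, Tokens/s/device: 158.213, total_weights: 65504, loss: 7.687, perplexity: 2179.917
completed step: 4009, seconds: 11.920, TFLOP/s/device: 184.929, Tokens/s/device: 171.807, total_weights: 65504, loss: 7.618, perplexity: 2034.629
"""

# Matches the "completed step" lines in the format shown above.
LINE = re.compile(
    r"completed step: (?P<step>\d+), seconds: (?P<seconds>[\d.]+), "
    r"TFLOP/s/device: [\d.]+, Tokens/s/device: (?P<tokens_per_s>[\d.]+), "
    r"total_weights: (?P<total_weights>\d+), loss: (?P<loss>[\d.]+), "
    r"perplexity: (?P<perplexity>[\d.]+)"
)


def parse_steps(log: str) -> list[dict[str, float]]:
    """Return one dict of floats per 'completed step' line."""
    return [
        {key: float(value) for key, value in m.groupdict().items()}
        for m in LINE.finditer(log)
    ]


steps = parse_steps(LOG)
for s in steps:
    # Perplexity should be exp(loss); the logged loss is rounded to 3 decimals.
    assert abs(math.log(s["perplexity"]) - s["loss"]) < 1e-3
    # Tokens handled by one device in this step (throughput x step time).
    s["tokens_per_device"] = round(s["seconds"] * s["tokens_per_s"])

mean_loss = sum(s["loss"] for s in steps) / len(steps)
print(f"mean loss {mean_loss:.3f}, perplexity {math.exp(mean_loss):.0f}")
print(f"target loss 2.77 -> perplexity {math.exp(2.77):.1f}")
print("tokens/device/step:", sorted({s["tokens_per_device"] for s in steps}))
```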

We confirmed from the logs that training loads the right checkpoint, the correct data splits, and the tokenizer.
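Logs show which files are loaded, but not whether the converted weights land under the right names, with the right shapes, in the right layer order. With scan_layers set to true, per-layer weights are stacked along an extra layer axis. A mistake such as a permuted layer order would therefore still load without error yet give a poor loss. The following is a minimal sketch of an offline check. It assumes that both the original weights and the converted ones can be loaded as nested dicts of NumPy arrays. The loading step is not shown, and the tree names (decoder, layers, kernel) and the helpers flatten, stack_layers and compare are hypothetical.

```python
import numpy as np


def flatten(tree, prefix=()):
    """Yield (path, array) pairs from a nested dict of arrays."""
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from flatten(value, path)
        else:
            yield path, np.asarray(value)


def stack_layers(per_layer, axis=0):
    """Stack a list of per-layer {name: array} dicts into {name: array} with an
    extra layer axis. The axis position is an assumption; match the target."""
    return {
        name: np.stack([layer[name] for layer in per_layer], axis=axis)
        for name in per_layer[0]
    }


def compare(reference, converted, atol=0.0):
    """List (path, problem) for weights that are missing, reshaped, or changed."""
    ref, conv = dict(flatten(reference)), dict(flatten(converted))
    problems = []
    for path in sorted(ref.keys() | conv.keys()):
        if path not in conv:
            problems.append((path, "missing in converted"))
        elif path not in ref:
            problems.append((path, "missing in reference"))
        elif ref[path].shape != conv[path].shape:
            problems.append((path, f"shape {ref[path].shape} vs {conv[path].shape}"))
        else:
            diff = np.abs(ref[path].astype(np.float32) - conv[path].astype(np.float32))
            if diff.max() > atol:
                problems.append((path, f"max abs diff {diff.max():g}"))
    return problems


# Toy usage with hypothetical names; real trees would come from the two checkpoints.
rng = np.random.default_rng(0)
layers = [{"kernel": rng.standard_normal((4, 3)).astype(np.float32)} for _ in range(2)]
reference = {"decoder": {"layers": stack_layers(layers)}}
good = {"decoder": {"layers": stack_layers(layers)}}
swapped = {"decoder": {"layers": stack_layers(layers[::-1])}}  # layer order reversed
wrong_axis = {"decoder": {"layers": stack_layers(layers, axis=1)}}
print(compare(reference, good))  # -> []
print(compare(reference, swapped))
print(compare(reference, wrong_axis))
```

In this sketch, a reversed layer order shows up as a value difference and a wrong stacking axis shows up as a shape mismatch.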

gramesh-amd, Sep 13 '24 03:09