Add Fuji v3 405B and solve HBM OOMs for larger models
Main changes in this PR:
- Change the Embedding partition spec from (None, "model") to ("fsdp", "model"), so the vocab dimension of the embedding table is sharded over the fsdp axis instead of left unsharded (memory sketch below)
- Add a new remat checkpoint on the TransformerLayer input and offload it to host memory
Reference: https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/configs/models/llama3.1-405b.yml
Requires the optimizer-state weight-only change to be merged for the fsdp=256, data=-1 config: https://github.com/apple/axlearn/pull/789
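For context on why the embedding spec change matters, here is a back-of-the-envelope sketch of the per-device HBM cost of the embedding table under each spec. This is illustrative only: the 16384 hidden dim is the Llama 3.1 405B reference value (an assumption, see the dim discussion further down), and fsdp=256, model=1 match the fsdp=256 data=-1 config this PR targets.

```python
# Rough HBM impact of sharding the embedding table over "fsdp" vs. leaving dim 0 unsharded.
# Assumptions (for illustration): vocab_size from this config (131072), hidden dim 16384
# (Llama 3.1 405B reference), fp32 master weights, fsdp=256, model=1.
vocab_size = 131072
hidden_dim = 16384
bytes_per_param = 4
fsdp, model = 256, 1

table_bytes = vocab_size * hidden_dim * bytes_per_param

# (None, "model"): only the "model" axis shards the table; with model=1 it is replicated.
per_device_old = table_bytes / model
# ("fsdp", "model"): the vocab dim is additionally sharded 256 ways over fsdp.
per_device_new = table_bytes / (fsdp * model)

print(f"embedding table: {table_bytes / 2**30:.1f} GiB total")
print(f"(None, 'model'):   {per_device_old / 2**20:.1f} MiB per device")
print(f"('fsdp', 'model'): {per_device_new / 2**20:.1f} MiB per device")
```

With model=1 in this mesh, the old spec keeps the full ~8 GiB table on every chip (plus fp32 optimizer state on top of that), while the new spec drops it to ~32 MiB per chip.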
@kelvin-zou @hanzhi713 I would appreciate your review to make sure this PR roughly matches the 405B reference. Thank you!
Getting this error:
NotFoundError: The specified path gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/bpe_128k_c4.model was not found.
Am I doing something wrong, or is this tokenizer missing from the bucket?
gsutil ls -r -l gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/
gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/:
778888 2023-11-22T23:02:56Z gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/bpe_32k_c4.model
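For anyone hitting the same NotFoundError: once a tokenizer is in the bucket, a quick way to confirm what it actually contains is to inspect it with sentencepiece directly (a sketch; assumes the sentencepiece pip package and a local copy of the .model file):

```python
# Inspect a SentencePiece model to confirm its vocab size, e.g. after
# `gsutil cp gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/bpe_32k_c4.model .`
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("bpe_32k_c4.model")   # local copy of the GCS object
print(sp.GetPieceSize())      # expect a ~32k vocab here; the 128k model was the missing one
```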
Fixed: this went away after the 128k vocab model was uploaded. Now I'm hitting OOM issues. Here is the model config:
max_step: 3932160
mesh_axis_names[0]: 'pipeline'
mesh_axis_names[1]: 'data'
mesh_axis_names[2]: 'expert'
mesh_axis_names[3]: 'fsdp'
mesh_axis_names[4]: 'seq'
mesh_axis_names[5]: 'model'
mesh_rules[0][0]: 'tpu-v6e-.*'
mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1
mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256
mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.nothing_saveable'
mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
mesh_shape[3]: 256
mesh_shape[4]: 1
mesh_shape[5]: 1
model.batch_axis_names[0]: 'data'
model.batch_axis_names[1]: 'expert'
model.batch_axis_names[2]: 'fsdp'
model.decoder.attention_mask: None
model.decoder.dim: 53248
model.decoder.dropout_rate: 0.0
model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.decoder.emb.token_emb.param_partition_spec[0]: None
model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
model.decoder.eos_token_id: 1
model.decoder.klass: 'axlearn.common.decoder.Decoder'
model.decoder.logits_partition_spec[0][0]: 'data'
model.decoder.logits_partition_spec[0][1]: 'expert'
model.decoder.logits_partition_spec[0][2]: 'fsdp'
model.decoder.logits_partition_spec[1]: 'seq'
model.decoder.logits_partition_spec[2]: 'model'
model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.output_norm.eps: 1e-05
model.decoder.output_norm.forward_dtype: None
model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.pad_token_id: 0
model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665
model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
model.decoder.transformer.layer.feed_forward.linear1.bias: False
model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.bias: False
model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: True
model.decoder.transformer.layer.remat_spec['policy']: 'jax._src.ad_checkpoint.nothing_saveable'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8
model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0
model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None
model.decoder.transformer.layer.self_attention.attention.num_heads: 128
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None
model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512
model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
model.decoder.transformer.layer.self_attention.norm.eps: 1e-05
model.decoder.transformer.layer.self_attention.norm.forward_dtype: None
model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.self_attention.structure: 'prenorm'
model.decoder.transformer.num_layers: 126
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.param_init.init_by_param_name['.*weight$'].scale: 1.0
model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.seq_axis_names[0]: 'seq'
model.z_loss_scale: 0.0
name: 'gpt_trainer'
prune_empty_state_updates: True
recorder.fn: '__main__.<lambda>'
save_input_iterator: False
start_trace_process_indices[0]: 0
start_trace_steps[0]: 16
summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
summary_writer.max_queue: 1000
summary_writer.write_every_n_steps: 100
train_dtype: 'jax.numpy.bfloat16'
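For debugging the OOM itself, dumping per-device HBM stats around the first train step helps narrow down whether parameter/optimizer state or activations are the problem. A minimal sketch (memory_stats() is available on TPU devices in recent jax; the exact key names can vary by backend):

```python
# Print per-device HBM usage on TPU to see how close the step is to the limit.
import jax

for d in jax.local_devices():
    stats = d.memory_stats()  # TPU backends expose bytes_in_use / bytes_limit, etc.
    if stats:
        used = stats.get("bytes_in_use", 0) / 2**30
        peak = stats.get("peak_bytes_in_use", 0) / 2**30
        limit = stats.get("bytes_limit", 0) / 2**30
        print(f"{d}: in_use={used:.2f} GiB peak={peak:.2f} GiB limit={limit:.2f} GiB")
```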
This PR is still in draft. I will update it once I get 405B working on Trillium.
The implementation is wrong: hidden_dim should be 16384 and ffn_dim should be 53248, right?
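Quick arithmetic check, using the Llama 3.1 405B reference shapes rather than the values in the dump above (so treat these numbers as assumptions, not something this config proves):

```python
# Llama 3.1 405B reference shapes: hidden 16384, FFN 53248, 126 layers,
# 128 query heads, 8 KV heads, head_dim 128. Rough parameter count, untied embeddings.
hidden, ffn, layers = 16384, 53248, 126
heads, kv_heads, head_dim = 128, 8, 128
vocab = 131072  # this config; the Meta checkpoint uses 128256

# 3.25, not the 8/3 ~= 2.6667 scale in the dump above
# (16384 * 8/3 rounded up to a multiple of 256 would give 43776, not 53248).
print(ffn / hidden)

attn = hidden * heads * head_dim * 2 + hidden * kv_heads * head_dim * 2  # q,o + k,v
mlp = 3 * hidden * ffn                                                    # gate, up, down
total = layers * (attn + mlp) + 2 * vocab * hidden                        # + emb + lm head
print(f"{total / 1e9:.1f}B params")  # lands around 405B with dim=16384, ffn=53248
```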
I will update this PR once I get my trillium-405b branch working.
Unable to run golden_config_test:
_______________________ ERROR collecting axlearn/experiments/golden_config_test.py ________________________
ImportError while importing test module '/Users/stoelinga/workspace/axlearn/axlearn/experiments/golden_config_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../miniforge3/envs/axlearn-10/lib/python3.10/importlib/__init__.py:126: in import_module
return _bootstrap._gcd_import(name[level:], package, level)
axlearn/experiments/golden_config_test.py:19: in <module>
from axlearn.experiments.audio import conformer
axlearn/experiments/audio/conformer/__init__.py:5: in <module>
from . import librispeech_trainer
axlearn/experiments/audio/conformer/librispeech_trainer.py:36: in <module>
from axlearn.audio.encoder_asr import SpeechFeatureLayer
axlearn/audio/encoder_asr.py:18: in <module>
from axlearn.audio.subsamplers import ConvSubSampler
axlearn/audio/subsamplers.py:10: in <module>
from axlearn.common.layers import BaseNormalizationLayer, Conv2DWith1DPadding, get_activation_fn
axlearn/common/layers.py:52: in <module>
from axlearn.common.quantized_dot_general.layers import DenseGeneralBaseLayer
axlearn/common/quantized_dot_general/layers.py:29: in <module>
from aqt.jax.v2.config import DotGeneral, set_context
E ImportError: cannot import name 'set_context' from 'aqt.jax.v2.config' (/Users/stoelinga/miniforge3/envs/axlearn-10/lib/python3.10/site-packages/aqt/jax/v2/config.py)
------------- generated xml file: /Users/stoelinga/workspace/axlearn/test-results/testing.xml -------------
============================================ 1 error in 22.07s ============================================
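That ImportError looks like a version mismatch between the installed AQT package and what axlearn.common.quantized_dot_general expects. A quick way to see what the installed module actually exports and which AQT release is present (a sketch; the pip distribution name is a guess and may differ by release):

```python
# Check whether the installed aqt.jax.v2.config still exposes `set_context`,
# and report the installed AQT package version if it can be found.
import importlib.metadata as md

import aqt.jax.v2.config as aqt_config

print([name for name in dir(aqt_config) if "context" in name.lower()])
for dist in ("aqtp", "aqt"):  # candidate distribution names
    try:
        print(dist, md.version(dist))
    except md.PackageNotFoundError:
        pass
```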
@kelvin-zou could you give it another review?
I added host offload for the TransformerLayer input checkpoint. This is required in order to run 405B. I did something similar in my full branch: https://github.com/apple/axlearn/compare/main...samos123:axlearn:trillium-405b-offload
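For reviewers unfamiliar with the mechanism, here is roughly what the offload does, written in plain JAX rather than via axlearn's RematSpecModifier (a sketch under my own naming, e.g. the "layer_input" tag; save_and_offload_only_these_names needs a recent JAX with host-offload support and a backend with pinned_host memory, i.e. TPU/GPU):

```python
# Plain-JAX illustration of "checkpoint the TransformerLayer input and offload it
# to host": tag the activation with a name, then use a host-offload remat policy.
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],                   # keep nothing resident in HBM
    names_which_can_be_offloaded=["layer_input"],  # hypothetical tag name, mine
    offload_src="device",
    offload_dst="pinned_host",
)

@partial(jax.checkpoint, policy=policy, prevent_cse=True)
def transformer_layer(x, w):
    x = checkpoint_name(x, "layer_input")  # this residual lives on host during bwd
    return jnp.tanh(x @ w)                 # stand-in for attention + feed-forward

def loss(x, w):
    return transformer_layer(x, w).sum()

grads = jax.grad(loss, argnums=1)(jnp.ones((4, 8)), jnp.ones((8, 8)))
```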
This needs to be rebased, since shared_lm_head was recently added to main.