rename context to context_autoregressive to separate CP naming for inference
Description
rename context_parallelism to context_autoregressive_parallelism to separate CP naming for inference
- training and inference need different sharding strategies for best performance, e.g., whether or not to FSDP-shard the QKV weights.
- introduce context_autoregressive_parallelism as the CP mesh axis used for inference only.
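To illustrate why a separate inference-only axis name helps, here is a minimal Python sketch of per-mode logical-axis rules. The rule tables, the logical names (`activation_length`, `embed`, `heads`) and the `resolve` helper are hypothetical: they only mimic the idea of mapping logical axes to mesh axes and are not MaxText's actual sharding rules.

```python
# Hypothetical rule tables: logical axis name -> tuple of mesh axis names.
# They only illustrate the idea; they are not MaxText's real sharding rules.
TRAIN_RULES = {
    "activation_length": ("context",),
    "embed": ("fsdp",),  # QKV weights FSDP-sharded for training
    "heads": ("tensor",),
}
INFERENCE_RULES = {
    "activation_length": ("context_autoregressive",),
    "embed": (),  # QKV weights left unsharded along embed for inference
    "heads": ("tensor",),
}


def resolve(logical_axes, rules):
    """Translate logical axis names into PartitionSpec-style entries.

    Each entry is None (replicated), a single mesh axis name, or a tuple
    of mesh axis names. Unknown logical names are treated as replicated.
    """
    spec = []
    for name in logical_axes:
        mesh_axes = rules.get(name, ())
        if not mesh_axes:
            spec.append(None)
        elif len(mesh_axes) == 1:
            spec.append(mesh_axes[0])
        else:
            spec.append(tuple(mesh_axes))
    return tuple(spec)


qkv_weight = ("embed", "heads")
attn_activation = ("activation_length", "heads")
print(resolve(qkv_weight, TRAIN_RULES))           # ('fsdp', 'tensor')
print(resolve(qkv_weight, INFERENCE_RULES))       # (None, 'tensor')
print(resolve(attn_activation, INFERENCE_RULES))  # ('context_autoregressive', 'tensor')
```

Because training and inference resolve the same logical names through different tables, changing how inference shards activations or QKV weights does not touch the training layout. The resulting tuples have the same shape as the entries accepted by `jax.sharding.PartitionSpec`.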
FIXES: b/405621754
Tests
Inference Tests:
LIBTPU_INIT_ARGS="--xla_tpu_enable_windowed_einsum_for_reduce_scatter=false --xla_jf_spmd_threshold_for_windowed_einsum_mib=1000000" python MaxText/inference_microbenchmark.py MaxText/configs/inference.yml max_prefill_predict_length=1024 max_target_length=2048 model_name=mixtral-8x22b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_context_autoregressive_parallelism=8 scan_layers=false per_device_batch_size=24 attention=dot_product megablox=False quantization=int8 checkpoint_is_quantized=True quantize_kvcache=True capacity_factor=1 tokenizer_path=assets/tokenizer.mistral-v3 compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3 enable_jax_profiler=True inference_microbenchmark_prefill_lengths="128,1024" base_output_directory=$OUTPUT run_name="trillium_22B_baseline" profiler="xplane" sparse_matmul=False model_call_mode=inference
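As a quick sanity check on the inference command, the sketch below parses a subset of its `key=value` overrides and multiplies the `ici_*_parallelism` values to get the number of devices the mesh expects. It assumes that any ici axis not listed in the command has size 1, and that the global batch is `per_device_batch_size` times the device count; the helper names are hypothetical.

```python
import math

# A subset of the overrides from the inference command above.
INFERENCE_ARGS = (
    "ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 "
    "ici_context_autoregressive_parallelism=8 per_device_batch_size=24"
)


def parse_overrides(arg_string):
    """Split space-separated key=value overrides into a dict of strings."""
    return dict(token.split("=", 1) for token in arg_string.split())


def ici_axis_sizes(overrides):
    """Keep only the ici_*_parallelism overrides, converted to int."""
    return {
        key: int(value)
        for key, value in overrides.items()
        if key.startswith("ici_") and key.endswith("_parallelism")
    }


overrides = parse_overrides(INFERENCE_ARGS)
axis_sizes = ici_axis_sizes(overrides)
# Assumption: every ici axis not listed in the command has size 1.
num_devices = math.prod(axis_sizes.values())
# Assumption: global batch = per_device_batch_size * num_devices.
global_batch = int(overrides["per_device_batch_size"]) * num_devices
print(num_devices)   # 8
print(global_batch)  # 192
```

With this split, all 8 devices go to the new context_autoregressive axis while fsdp and autoregressive stay at 1.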
Train Tests:
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=81920 --xla_tpu_enable_async_all_to_all=true'
python3 MaxText/train.py MaxText/configs/base.yml per_device_batch_size=1 skip_first_n_steps_for_profiler=5 steps=10 max_target_length=4096 tokenizer_path=assets/tokenizer.mistral-v1 dtype=bfloat16 weight_dtype=bfloat16 sparse_matmul=False capacity_factor=1.25 dataset_type=synthetic profiler=xplane ici_fsdp_parallelism=4 attention=flash model_name=mixtral-8x7b run_name=${WORKLOAD_NAME} sa_block_q=2048 sa_block_kv=2048 sa_block_kv_compute=2048 sa_block_q_dkv=2048 sa_block_kv_dkv=2048 sa_block_kv_dkv_compute=2048 sa_block_q_dq=2048 sa_block_kv_dq=2048
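The training command sets every `sa_block_*` flash-attention block size to 2048 with `max_target_length=4096`. The sketch below checks that each block size tiles the sequence evenly and reports how many blocks that gives. It assumes, without confirming it here, that the attention kernel requires the sequence length to be a multiple of each block size; the `block_counts` helper is hypothetical.

```python
# Block-size overrides copied from the training command above.
TRAIN_ARGS = (
    "max_target_length=4096 sa_block_q=2048 sa_block_kv=2048 "
    "sa_block_kv_compute=2048 sa_block_q_dkv=2048 sa_block_kv_dkv=2048 "
    "sa_block_kv_dkv_compute=2048 sa_block_q_dq=2048 sa_block_kv_dq=2048"
)


def block_counts(arg_string):
    """Return {flag: blocks per sequence}; raise if a block doesn't tile it."""
    overrides = dict(token.split("=", 1) for token in arg_string.split())
    seq_len = int(overrides["max_target_length"])
    counts = {}
    for key, value in overrides.items():
        if not key.startswith("sa_block_"):
            continue
        block = int(value)
        # Assumed constraint: the sequence must split into whole blocks.
        if seq_len % block != 0:
            raise ValueError(
                f"{key}={block} does not divide max_target_length={seq_len}"
            )
        counts[key] = seq_len // block
    return counts


counts = block_counts(TRAIN_ARGS)
print(counts["sa_block_q"])  # 2
```

Each of the eight block flags splits the 4096-token sequence into 2 blocks.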
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed.
Qinwen is OOO for a while, so I've copied this PR to https://github.com/AI-Hypercomputer/maxtext/pull/1501 so that we can merge it soon.
any updates on this?
This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.
This PR was closed because it has been inactive for a while. Please reopen it if you are still working on it.