fix: Pin orbax-checkpoint to v0.10.3 to resolve dependency error (#1273)

Open shota-inoue-lts opened this issue 11 months ago • 4 comments

Situation

I run the following shell script to train a model with MaxText via xpk.

#!/bin/bash
# GCP Settings
PROJECT=XXXXXXX
ZONE=XXXXXXX
CLUSTER=XXXXXXX
TPU_TYPE=v6e-8
NUM_SLICES=1

# Storage path
BASE_OUTPUT_DIR=XXXXXXX
DATASET_PATH=XXXXXXX
DATASET_TYPE=tfds

# HyperParameters
PER_DEVICE_BATCH_SIZE=3
MODEL_NAME=llama3.1-8b
MAX_TARGET_LENGTH=4096
STEPS=35
BLOCK_SIZE=2048
REMAT_POLICY=full
TOKENIZER_PATH=assets/tokenizer_llama3.tiktoken
VMEM_LIMIT=114688
ENABLE_CHECKPOINTING=true
CHECKPOINT_PERIOD=30

# Parallelism
ICI_DATA_PARALLELISM=1
ICI_PIPELINE_PARALLELISM=4
ICI_FSDP_PARALLELISM=1
ICI_FSDP_TRANSPOSE_PARALLELISM=1
ICI_SEQUENCE_PARALLELISM=1
ICI_TENSOR_PARALLELISM=2
ICI_TENSOR_SEQUENCE_PARALLELISM=1
ICI_EXPERT_PARALLELISM=1
ICI_AUTOREGRESSIVE_PARALLELISM=1

# image settings
CLOUD_IMAGE_NAME=${USER}_runner
DOCKER_IMAGE=gcr.io/${PROJECT}/${CLOUD_IMAGE_NAME}:latest

# EXP settings
EXP_NAME=$(echo $MODEL_NAME | tr '.' '-')-bs${PER_DEVICE_BATCH_SIZE}-$(date +'%m-%d-%H-%M-%S') # --workload: Workload name must be less than 40 characters and match the pattern `[a-z]([-a-z0-9]*[a-z0-9])?`

# download dataset
cd ~/maxtext
bash download_dataset.sh ${PROJECT} ${DATASET_PATH}

# create and push image
cd ~/maxtext
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}

# create workload for model training
cd ~/xpk
python3 xpk.py workload create \
    --cluster ${CLUSTER} \
    --docker-image ${DOCKER_IMAGE} \
    --workload ${EXP_NAME} \
    --tpu-type ${TPU_TYPE} \
    --num-slices ${NUM_SLICES}  \
    --use-vertex-tensorboard \
    --experiment-name ${EXP_NAME} \
    --zone ${ZONE} \
    --on-demand \
    --enable-debug-logs \
    --project ${PROJECT} \
    --command "export LIBTPU_INIT_ARGS='--xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_scoped_vmem_limit_kib=${VMEM_LIMIT} --xla_tpu_enable_async_collective_fusion=true --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true' && python3 MaxText/train.py MaxText/configs/base.yml model_name=${MODEL_NAME} base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} run_name=${EXP_NAME} tokenizer_path=${TOKENIZER_PATH} max_target_length=${MAX_TARGET_LENGTH} per_device_batch_size=${PER_DEVICE_BATCH_SIZE} remat_policy=${REMAT_POLICY} steps=${STEPS} enable_checkpointing=${ENABLE_CHECKPOINTING} checkpoint_period=${CHECKPOINT_PERIOD} use_iota_embed=true gcs_metrics=true dataset_type=${DATASET_TYPE} reuse_example_batch=1 profiler=xplane attention=flash sa_block_q=${BLOCK_SIZE} sa_block_q_dkv=${BLOCK_SIZE} sa_block_q_dq=${BLOCK_SIZE} ici_data_parallelism=${ICI_DATA_PARALLELISM} ici_pipeline_parallelism=${ICI_PIPELINE_PARALLELISM} ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} ici_fsdp_transpose_parallelism=${ICI_FSDP_TRANSPOSE_PARALLELISM} ici_sequence_parallelism=${ICI_SEQUENCE_PARALLELISM} ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} ici_tensor_sequence_parallelism=${ICI_TENSOR_SEQUENCE_PARALLELISM} ici_expert_parallelism=${ICI_EXPERT_PARALLELISM} ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM}"
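
The comment on EXP_NAME above quotes the xpk workload-name constraint (fewer than 40 characters, matching `[a-z]([-a-z0-9]*[a-z0-9])?`). A minimal sketch of checking the name before submitting the workload follows; `is_valid_workload_name` is a hypothetical helper, not part of xpk, and the constraint is taken only from that comment.

```shell
#!/bin/bash
# Hypothetical helper (not part of xpk): check a workload name against the
# constraint quoted in the script comment: fewer than 40 characters and
# matching the pattern [a-z]([-a-z0-9]*[a-z0-9])?
is_valid_workload_name() {
  local name="$1"
  local pattern='^[a-z]([-a-z0-9]*[a-z0-9])?$'
  [[ ${#name} -lt 40 && "$name" =~ $pattern ]]
}

# Same construction as EXP_NAME in the script above.
MODEL_NAME=llama3.1-8b
PER_DEVICE_BATCH_SIZE=3
EXP_NAME=$(echo "$MODEL_NAME" | tr '.' '-')-bs${PER_DEVICE_BATCH_SIZE}-$(date +'%m-%d-%H-%M-%S')

if is_valid_workload_name "$EXP_NAME"; then
  echo "OK: ${EXP_NAME}"
else
  echo "Invalid workload name: ${EXP_NAME}" >&2
fi
```

Replacing the dots with `tr` matters here: the raw model name `llama3.1-8b` would fail the pattern check.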

Error Message

I got the following error while running MaxText/train.py. Specifically, the error occurs when checkpointing is enabled (ENABLE_CHECKPOINTING=true).

Traceback (most recent call last):
  File "/deps/MaxText/train.py", line 1031, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/deps/MaxText/train.py", line 1027, in main
    train_loop(config)
  File "/deps/MaxText/train.py", line 897, in train_loop
    if save_checkpoint(checkpoint_manager, int(step), state_to_save, config.dataset_type, data_iterator, config):
  File "/deps/MaxText/train.py", line 241, in save_checkpoint
    return checkpoint_manager.save(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1278, in save
    self._checkpointer.save(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 491, in save
    asyncio_utils.run_sync(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
  File "/usr/local/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/local/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 392, in _save
    await self._handler.async_save(tmpdir.get(), args=ckpt_args) or []
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 706, in async_save
    jax.tree.flatten(await asyncio.gather(*save_ops))[0] or []
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 583, in async_save
    return await self._handler_impl.async_save(directory, args=args)
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 482, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py", line 1127, in serialize
    future.CommitFutureAwaitingContractedSignals(
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py", line 367, in __init__
    receive_signals = get_awaitable_signals_from_contract()
  File "/usr/local/lib/python3.10/site-packages/orbax/checkpoint/_src/futures/future.py", line 57, in get_awaitable_signals_from_contract
    values_str = str(client.key_value_try_get(barrier_key))
AttributeError: 'DistributedRuntimeClient' object has no attribute 'key_value_try_get'. Did you mean: 'key_value_dir_get'?
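
To see which package versions actually ended up in the image, a check along these lines can be run inside the container. This is a minimal sketch: `pkg_version` is a hypothetical helper name, and it assumes only that `python3` is on the PATH.

```shell
#!/bin/bash
# Hypothetical helper (not part of MaxText or xpk): print the installed
# version of a Python package, or "not installed" if it cannot be found.
pkg_version() {
  python3 -c 'import sys
from importlib.metadata import version
print(version(sys.argv[1]))' "$1" 2>/dev/null || echo "not installed"
}

# Packages involved in the traceback above.
for pkg in jax jaxlib orbax-checkpoint; do
  echo "${pkg}: $(pkg_version "$pkg")"
done
```

The helper reads package metadata rather than importing the packages, so it does not initialize JAX or touch the TPU runtime.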

Solution

We should pin orbax-checkpoint==0.10.3 when building the Docker image (without a version specifier, orbax-checkpoint==0.11.5 is currently installed). The traceback shows that this newer Orbax release calls key_value_try_get on the JAX distributed client, which the JAX version in the base image (jax0.4.37-rev1) does not provide. We solved the problem by editing both requirements files (requirements_with_jax_stable_stack.txt and requirements.txt).

# maxtext/requirements_with_jax_stable_stack.txt
...
orbax-checkpoint==0.10.3
...
# maxtext/requirements.txt
...
orbax-checkpoint==0.10.3
...
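
The same edit can be scripted so it is repeatable before each image build. The following is a minimal sketch, not the change made in this PR: `pin_requirement` is a hypothetical helper that removes any existing entry for the package and appends the pinned line at the end of the file, which assumes the position of the line within the file does not matter.

```shell
#!/bin/bash
# Hypothetical helper: pin a package to an exact version in a requirements file.
# Usage: pin_requirement <file> <package> <version>
pin_requirement() {
  local file="$1" pkg="$2" version="$3"
  local tmp
  tmp="$(mktemp)"
  # Drop any existing entry for the package (bare name or with a version
  # specifier); grep exits non-zero when nothing is left, hence "|| true".
  grep -v -E "^${pkg}([=<>~ ].*)?\$" "$file" > "$tmp" || true
  # Append the pinned entry.
  echo "${pkg}==${version}" >> "$tmp"
  # Write back in place so the original file's permissions are kept.
  cat "$tmp" > "$file"
  rm -f "$tmp"
}

# Example (paths taken from the description above):
# pin_requirement ~/maxtext/requirements.txt orbax-checkpoint 0.10.3
# pin_requirement ~/maxtext/requirements_with_jax_stable_stack.txt orbax-checkpoint 0.10.3
```

Running the helper twice leaves a single `orbax-checkpoint==0.10.3` line, so it is safe to call from a build script.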

Reference

I referred to the following URL when creating the shell script.

How to run MaxText with XPK? https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Run_MaxText_via_xpk.md

shota-inoue-lts, Feb 14 '25 13:02

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot], Feb 14 '25 13:02

Hi @shota-inoue-lts,

Thank you so much for looking into this! We are working on fixing this issue by updating the dependencies, without needing to make a change in MaxText. In the meantime, you could also try updating your JAX version to the following, which should resolve the error:

jax==0.5.0
jaxlib==0.5.0
jaxtyping==0.2.38
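
A quick way to tell which of the two workarounds applies to a given image is to compare the installed JAX version against the one suggested above. This is only a sketch: `version_ge` is a hypothetical helper, it relies on `sort -V` (available in GNU coreutils), and treating 0.5.0 as the threshold is an assumption based solely on the versions listed in this comment.

```shell
#!/bin/bash
# Hypothetical helper: true when version $1 >= version $2.
# Relies on version-aware sorting (`sort -V`, e.g. GNU coreutils).
version_ge() {
  [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n 1)" = "$2" ]
}

# Assumption: 0.5.0 is used as the threshold only because it is the
# version suggested in this thread; the exact boundary is not confirmed.
MIN_JAX="0.5.0"

# Read the installed version from package metadata; fall back to "0".
installed_jax="$(python3 -c 'from importlib.metadata import version
print(version("jax"))' 2>/dev/null || echo "0")"

if version_ge "$installed_jax" "$MIN_JAX"; then
  echo "jax ${installed_jax}: the orbax-checkpoint pin is probably not needed"
else
  echo "jax ${installed_jax}: upgrade JAX or pin orbax-checkpoint==0.10.3"
fi
```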

A9isha, Feb 20 '25 09:02

Hi @A9isha, thank you for suggesting an easier way to fix the problem!

shota-inoue-lts, Feb 20 '25 11:02

@A9isha @shota-inoue-lts if the problem is fixed, can this PR and the corresponding issue be closed now?

shralex, May 01 '25 13:05