maxtext
A simple, performant and scalable Jax LLM!
Can we refactor the imports so that MaxText can be used as a proper Python module? It is currently hard for developers to use it or build on top of it. - Blocking inference development with JetStream....
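For context, a minimal sketch of the difference such a refactor would make, assuming a packaged layout with `__init__.py` files and absolute package imports; the module and path names below (`maxtext`, `pyconfig`, the `sys.path` entry) are illustrative, not the project's actual layout:

```python
# Today, reusing MaxText code from another project typically means putting the
# source directory on sys.path before importing its top-level modules.
import sys
sys.path.append("/path/to/maxtext/MaxText")  # fragile and path-dependent (illustrative)
import pyconfig  # resolves only because of the sys.path hack above

# With a packaged layout, the same code could be imported directly:
# from maxtext import pyconfig, train   # hypothetical package-style import
```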
Hi, I was testing multi-host training on a v4-16 TPU VM. Training normally runs smoothly, but it sometimes crashes at `load_next_batch` with the following error from the process...
1. When local checkpoints are available for restore, alter the mesh setup as follows (a minimal sketch follows this item): ignore the JAX coordinator provided by XPK and override the JAX coordinator to be the pod...
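A minimal sketch of what overriding the coordinator could look like, assuming the pod address is exposed through an environment variable; the variable names and port below are illustrative assumptions, not the values XPK actually sets:

```python
import os
import jax

# Illustrative: take the coordinator from a pod-level variable instead of the one
# XPK injects; POD_COORDINATOR_ADDR, NUM_PROCESSES and PROCESS_ID are placeholders.
coordinator = os.environ.get("POD_COORDINATOR_ADDR", "10.0.0.1:1234")

jax.distributed.initialize(
    coordinator_address=coordinator,                 # override the XPK-provided coordinator
    num_processes=int(os.environ["NUM_PROCESSES"]),  # placeholder env var
    process_id=int(os.environ["PROCESS_ID"]),        # placeholder env var
)
```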
Adding a new feature, `gradient accumulation`, so that weights are only updated every x steps (a sketch of the technique follows below). Example command without `gradient accumulation`:

```
python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${MAXTEXT_OUTPUT_PATH} run_name=${RUN_NAME} enable_checkpointing=false async_checkpointing=false per_device_batch_size=1 skip_first_n_steps_for_profiler=5 steps=30...
```
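For reference, a minimal, self-contained sketch of the gradient-accumulation technique using `optax.MultiSteps`; this is not the MaxText implementation, just an illustration of applying weight updates only every k micro-batches (the toy model, data, and value of k are made up):

```python
import jax
import jax.numpy as jnp
import optax

k = 4  # number of micro-batches to accumulate over (illustrative)
opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=k)

def loss_fn(params, x, y):
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # MultiSteps accumulates gradients internally and returns zero updates
    # until the k-th call, at which point the inner optimizer is applied.
    updates, opt_state = opt.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

params = {"w": jnp.zeros((3,))}
opt_state = opt.init(params)
x, y = jnp.ones((8, 3)), jnp.ones((8,))
for _ in range(2 * k):  # weights actually change only twice in this loop
    params, opt_state = train_step(params, opt_state, x, y)
```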
Is FlashAttention supported on TPUv3? The same config that works on TPUv4 fails on TPUv3 with the following error: `jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data...
# This is created as a draft PR for GCS internal members to comment on. This will not be merged to main.

## Checkpointing a 64B model through MaxText

- Read...
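As background for this snippet, a toy sketch of saving and restoring a parameter pytree with Orbax, which MaxText's checkpointing is built on; the path and pytree here are illustrative, and a real 64B-parameter run would rely on multi-host, asynchronous checkpointing rather than this local example:

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Illustrative parameter tree; a real model would hold sharded arrays.
params = {"layer0": {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}}

ckptr = ocp.PyTreeCheckpointer()
ckptr.save("/tmp/maxtext_demo_ckpt", params, force=True)  # illustrative local path
restored = ckptr.restore("/tmp/maxtext_demo_ckpt")        # same tree structure back
```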
Add the `enable_model_warmup` flag at model server start. Associated PR: https://github.com/google/JetStream/pull/92

```
- model_name=gemma-7b
- tokenizer_path=assets/tokenizer.gemma
- per_device_batch_size=1
- max_prefill_predict_length=1024
- max_target_length=2048
- async_checkpointing=false
- ici_fsdp_parallelism=1
- ici_autoregressive_parallelism=-1
- ici_tensor_parallelism=1...
```