maxtext

A simple, performant and scalable Jax LLM!

Results 159 maxtext issues

Can we refactor the imports so that MaxText can be used as Python modules? It's pretty hard for developers to use it or build on top of it. - Blocking inference development with JetStream....

feature request

Hi, I was testing multi-host training on a v4-16 TPU VM. Training normally runs smoothly, but it sometimes crashes at `load_next_batch` with the following error from the process...

bug

1. When local checkpoints are available for restore, alter mesh setup as follows. - Ignore the JAX coordinator provided by XPK and override the JAX coordinator to be the pod...

Adds a new `gradient accumulation` feature that updates the weights only once every x steps. Example command without using `gradient accumulation`: ``` python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${MAXTEXT_OUTPUT_PATH} run_name=${RUN_NAME} enable_checkpointing=false async_checkpointing=false per_device_batch_size=1 skip_first_n_steps_for_profiler=5 steps=30...
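
For readers unfamiliar with the idea, gradient accumulation in general can be sketched as below. This is a minimal pure-Python illustration on a toy one-parameter model, not the code from the PR: the function name `sgd_with_accumulation`, the `accum_steps` parameter, and the choice to average the accumulated gradients are all assumptions made for the example.

```python
def sgd_with_accumulation(w, xs, ys, lr, accum_steps):
    """Fit y = w * x with squared loss, updating w once every accum_steps micro-batches.

    Hypothetical helper for illustration only; it is not part of MaxText.
    """
    grad_sum = 0.0
    for step, (x, y) in enumerate(zip(xs, ys), start=1):
        # Gradient of (w * x - y) ** 2 with respect to w for one micro-batch.
        grad_sum += 2.0 * (w * x - y) * x
        if step % accum_steps == 0:
            # Apply the averaged gradient, then reset the accumulator.
            w -= lr * grad_sum / accum_steps
            grad_sum = 0.0
    return w


# Accumulating over both micro-batches gives a single update.
w_accumulated = sgd_with_accumulation(0.0, [1.0, 2.0], [2.0, 4.0], lr=0.1, accum_steps=2)
# accum_steps=1 reduces to ordinary per-step SGD (two updates).
w_per_step = sgd_with_accumulation(0.0, [1.0, 2.0], [2.0, 4.0], lr=0.1, accum_steps=1)
```

With `accum_steps` equal to the number of micro-batches, the result matches one step on the full batch, which is why accumulation is commonly used to emulate a larger batch size when memory is limited.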

Is FlashAttention supported on TPUv3? The same config that works on TPUv4 fails on TPUv3 with the following error: `jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data...

# This is created as a draft PR for GCS internal members to comment. This will not be merged to main. ## Checkpointing a 64B model through MaxText - Read...

Add the `enable_model_warmup` flag at model server start Associated PR: https://github.com/google/JetStream/pull/92 ``` - model_name=gemma-7b - tokenizer_path=assets/tokenizer.gemma - per_device_batch_size=1 - max_prefill_predict_length=1024 - max_target_length=2048 - async_checkpointing=false - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=-1 - ici_tensor_parallelism=1...