maxtext
A simple, performant and scalable Jax LLM!
# [[Bug] adam_pax has reuse donated buffer warning](https://github.com/google/maxtext/issues/490)

Reproduced with `weight_dtype=bfloat16`:

```shell
python3 MaxText/train.py MaxText/configs/base.yml run_name=run steps=10 weight_dtype=bfloat16 opt_type=adam_pax dataset_type=synthetic enable_checkpointing=false
```

```
/home/lizhiyu/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:914: UserWarning: Some donated buffers were not...
```
Hi, I noticed that when using `adam_pax` instead of `adamw` as the optimizer, training emits a `reuse donated buffer` warning. I am wondering if this is expected, and why the code...
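For context on the warning itself: in JAX generally, a donated input buffer can only be reused for an output with the same shape and dtype. The sketch below is an assumption-laden illustration of that mechanism, not a diagnosis of this issue. The `step` function and `state` array are hypothetical and are not taken from MaxText or `adam_pax`. It donates a `bfloat16` array to a jitted function that returns `float32`, so no output can alias the donated buffer, and it records any donation-related warning JAX emits.

```python
import warnings

import jax
import jax.numpy as jnp


def step(state):
    # Hypothetical update: the bfloat16 input is promoted to float32, so the
    # output no longer matches the dtype of the donated input buffer.
    return state.astype(jnp.float32) + 1.0


# donate_argnums=0 tells JAX it may reuse the first argument's buffer
# for an output of the compiled function.
donating_step = jax.jit(step, donate_argnums=0)

state = jnp.zeros((4,), dtype=jnp.bfloat16)

# Record warnings instead of printing them, so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    new_state = donating_step(state)

# Keep only warnings that mention donated buffers.
donation_warnings = [w for w in caught if "donated buffers" in str(w.message)]
print(new_state.dtype, len(donation_warnings))
```

Whether a warning is recorded can depend on the JAX version and backend, so the sketch prints the count rather than assuming it. If the output dtype is made to match the input, the donated buffer becomes eligible for reuse.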
Changes checkpointing to use the new API and replaces `default` with `items`. Also adds an option to restore checkpoints with `SingleReplicaArrayHandler`.
Hi, I'm trying to understand some details in the TFDS data processing pipeline in your repo, and I'm confused about the following details: **In `_tfds_data_processing.py`:** (1) The `truncate_to_max_allowable_length` function truncates...
Hi, Thanks for the library! I'm new to the JAX+LLM ecosystem and trying to understand which library I should be using. I see a lot of (very impressive) computational efficiency...
Creating a new branch for GCS Tessellation.
Integrate Goodput library with MaxText

This PR includes:
- Install Goodput dependency (ml-goodput-measurement in requirements.txt)
- Add config options to enable Goodput
- Update MaxText's train.py to use Goodput APIs...