maxtext
A simple, performant and scalable Jax LLM!
The code works fine on TPU but crashes on GPU with a strange error. (See the test logs below, captured while running train.py.)
Tested by creating a multihost_job, simulating a maintenance event, and confirming that the logs are still visible after recovering from the simulated maintenance event: Create the multihost_job ``` python3 multihost_job.py --COMMAND="bash...
The standalone data loader sets up the model and data iterator in the same way as the train_loop in train.py. The data loader then iterates through batches of data to log the step time of...
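A minimal sketch of such a timing loop (illustrative names only, not the actual MaxText implementation):

```python
import time

# Illustrative sketch: iterate over the data iterator and log per-step load time,
# mirroring what a standalone data loader does without running the model.
def data_load_loop(data_iterator, num_steps):
  for step in range(num_steps):
    start = time.time()
    batch = next(data_iterator)  # same iterator setup as train.py's train_loop
    print(f"step {step}: data load time {time.time() - start:.3f}s")
```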
* `--shm-size` is increased to `1g` for `docker run` on GPU because the default value of 64 MB may not be sufficient for some GPU configurations (e.g. A100-40gb-8)
* `--shm-size=1g`...
* Cloud monitoring prototype
* Checkpoint initialization metrics emitting
The assertion doesn't check that `determined_val` is an integer, and the function call is missing.
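A minimal sketch of the intended check, assuming a hypothetical `determine_val()` helper (names are illustrative, not from the MaxText codebase):

```python
def determine_val():
  """Hypothetical helper standing in for the real computation."""
  return 42

# The helper must actually be invoked, not just referenced, and the result type checked.
determined_val = determine_val()
assert isinstance(determined_val, int), f"expected int, got {type(determined_val)}"
```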
See if this is an improvement for your purposes. This PR modifies the multihost data put code to infer the global shapes and build `NamedSharding`s lazily at load time. This...
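A hedged sketch of what building `NamedSharding`s lazily at load time could look like (the function name, partition spec, and data layout are assumptions, not the PR's actual code):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def put_to_devices(local_batch: np.ndarray, mesh: Mesh, pspec: P = P("data")):
  """Illustrative only: infer the global shape from the per-host shard and build
  the NamedSharding when the batch is loaded, rather than ahead of time."""
  # Global batch dimension = per-host batch * number of participating processes.
  global_shape = (local_batch.shape[0] * jax.process_count(),) + local_batch.shape[1:]
  sharding = NamedSharding(mesh, pspec)  # built lazily, at load time
  local_devices = mesh.local_devices
  shards = np.split(local_batch, len(local_devices), axis=0)
  arrays = [jax.device_put(s, d) for s, d in zip(shards, local_devices)]
  return jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
```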
Starting the profile at step 0 results in an error if training is re-run from a checkpoint saved at step 0 or later. This change starts the profile at the end of...
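For context, a minimal sketch of gating the profiler on a configurable start step (the step numbers, trace directory, and training step are placeholders, not MaxText's config values):

```python
import jax
import jax.numpy as jnp

# Placeholder values; in a real run these would come from the training config.
profiler_start_step = 5
profiler_end_step = 10

@jax.jit
def train_step(x):
  return x * 2.0  # stand-in for the real training step

state = jnp.ones((4,))
for step in range(20):
  if step == profiler_start_step:
    jax.profiler.start_trace("/tmp/profile")  # placeholder output directory
  state = train_step(state)
  if step == profiler_end_step:
    jax.profiler.stop_trace()
```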