
A simple, performant and scalable Jax LLM!

Results: 159 maxtext issues, sorted by recently updated.

Adds a ragged attention kernel in Pallas, along with a unit test for the new code. Note that the ragged attention kernel is not actively used in this code...
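The PR text doesn't show the kernel itself. As a hedged sketch of what "ragged" attention computes (per-example key/value lengths, padded positions masked out), here is a plain-JAX reference implementation, not the Pallas kernel from the PR; all names are illustrative:

```python
import jax
import jax.numpy as jnp


def ragged_attention(q, k, v, lengths):
    """Reference (non-Pallas) ragged attention.

    q: [batch, heads, q_len, head_dim]
    k, v: [batch, heads, kv_len, head_dim]
    lengths: [batch] — number of valid key/value positions per example.
    """
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    kv_len = k.shape[2]
    # Mask out padded key/value positions beyond each example's length.
    valid = jnp.arange(kv_len)[None, :] < lengths[:, None]      # [batch, kv_len]
    scores = jnp.where(valid[:, None, None, :], scores, -jnp.inf)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", probs, v)
```

A Pallas version would tile this computation and skip fully-masked blocks; the masking semantics above are the part the unit test would pin down.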

- Implemented cuDNN flash attention with Transformer Engine. It currently supports head_dim up to 128 and does not yet support GQA. The API is unstable and will change soon...

python MaxText/decode.py MaxText/configs/base.yml per_device_batch_size=64 run_name=runner_2024-01-30-20-02 max_prefill_predict_length=128 max_target_length=256 dataset_path=gs://maxtext-dataset async_checkpointing=false scan_layers=false attention=dot_product ici_autoregressive_parallelism=4

Achieves 400 GB/s/device on a v4-8.

Example of changing where profiling starts; here, profiling begins at data loading.
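The summary doesn't show the mechanism. A minimal sketch of moving the profiling start point in a JAX training loop, using `jax.profiler.start_trace`/`stop_trace` around a chosen step; the loop, `profile_step`, and `log_dir` are illustrative stand-ins, not MaxText config keys:

```python
import jax
import jax.numpy as jnp


def train(num_steps, profile_step=3, log_dir="/tmp/jax-profile"):
    loss = None
    for step in range(num_steps):
        if step == profile_step:
            # Start the trace just before this step's data loading,
            # so the loader appears at the beginning of the capture.
            jax.profiler.start_trace(log_dir)
        batch = jnp.ones((8, 128))        # stand-in for data loading
        loss = jnp.mean(batch * 2.0)      # stand-in for the train step
        if step == profile_step:
            jax.profiler.stop_trace()
    return float(loss)
```

Shifting the `start_trace` call earlier or later in the loop body is what controls which phase (data loading vs. compute) the captured trace begins with.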

Quick prototype to compute Goodput from total step time and total job time in MaxText.
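The summary gives only the inputs (step time and job time), not the formula. A common definition, offered here as an assumption rather than the PR's actual code, is the fraction of wall-clock job time spent in productive training steps:

```python
def compute_goodput(avg_step_time_s, num_steps, total_job_time_s):
    """Goodput as productive step time over total job time (assumed definition).

    avg_step_time_s: average duration of one training step, in seconds.
    num_steps: number of completed steps.
    total_job_time_s: wall-clock job duration, including startup,
        checkpoint restores, and any restarts.
    """
    productive_time_s = avg_step_time_s * num_steps
    return productive_time_s / total_job_time_s
```

For example, 100 steps at 2 s each inside a 250 s job gives a Goodput of 0.8; the missing 20% is overhead such as initialization and checkpointing.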

Adds inference decode configurations intended for CPU, for model sizes of 1B, 4B, 8B, and 16B parameters.