paxml
Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading...
Corrected Grok model parameters to match OSS Grok model
Adds support for NVIDIA's [Transformer Engine](https://github.com/NVIDIA/TransformerEngine). TE can be enabled by setting the environment variable `ENABLE_TE=1`. For more details about running Pax with Transformer Engine, refer to the [JAX Toolbox...
Bumps the pip group with 1 update in the /paxml/pip_package directory: [tensorflow](https://github.com/tensorflow/tensorflow). Updates `tensorflow` from 2.9.3 to 2.11.1. Release notes, sourced from tensorflow's releases: TensorFlow 2.11.1. Release 2.11.1. Note: TensorFlow...
> > Additional GRPC error information from remote target unknown_target_for_coordination_leader while calling /tensorflow.CoordinationService/RegisterTask: > :{"created":"@1712965181.656280441","description":"Deadline Exceeded","file":"external/com_github_grpc_grpc/src/core/ext/filters/deadline/deadline_filter.cc","file_line":69,"grpc_status":4} > 2024-04-12 23:39:41.656900: E external/xla/xla/pjrt/distributed/client.cc:96] Coordination service agent in error status: DEADLINE_EXCEEDED: Deadline Exceeded...
NOT FOR COMMIT Depends on https://github.com/google/praxis/pull/51
This PR allows users to enable cuDNN flash attention. It depends on https://github.com/google/praxis/pull/53. In preliminary results for GPT3-5B, we observe a ~30% performance improvement on...
I used the aqt_einsum function in the code to quantize only the qk score, and then trained the model. However, I found that the loss dropped very slowly after training...
Pip complains when a requirements file is passed without a preceding -r.
@zhangqiaorjc please review this change to support per core batch size < 1 with the synthetic dataset.
I'm running paxml on an Intel Xeon CPU server using the paxml/main.py program. I'm trying to create a model that creates weights in bfloat16, and uses that datatype during eval....