jax_transformer
Autoregressive Transformer Decoder in JAX from scratch
This implementation builds a transformer decoder from the ground up. It doesn't use any higher-level framework like Flax; I have used labml for logging and experiment tracking.
I have implemented a simple Module class as a foundation for the basic building blocks.
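A minimal sketch of what such a Module class might look like (hypothetical; the actual class in this repository may differ): layers hold their parameters as attributes, and parameters can be collected recursively so they can be passed through JAX transformations.

```python
import jax.numpy as jnp
from jax import random


class Module:
    """Hypothetical base class: collects parameters from attributes recursively."""

    def get_params(self):
        params = {}
        for name, value in self.__dict__.items():
            if isinstance(value, Module):
                params[name] = value.get_params()  # nested sub-module
            elif isinstance(value, jnp.ndarray):
                params[name] = value  # leaf parameter array
        return params


class Linear(Module):
    """A linear layer built on the Module base class."""

    def __init__(self, key, in_features, out_features):
        self.weight = random.normal(key, (in_features, out_features)) * 0.01
        self.bias = jnp.zeros(out_features)

    def __call__(self, x):
        return x @ self.weight + self.bias
```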
This was my first JAX project, and many implementations were adapted from the PyTorch implementations at nn.labml.ai.
JAX can optimize and differentiate pure Python functions. A pure function takes a set of arguments and returns a result without side effects, such as modifying global state. JAX can also compile these functions with XLA, as well as vectorize them, to run them efficiently.
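For example, here is a small, self-contained illustration (not taken from this repository) of JAX transforming a pure function: `jax.grad` produces its gradient and `jax.jit` compiles it with XLA.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Pure function: the output depends only on the inputs, no side effects.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


grad_loss = jax.grad(loss)  # gradient of loss w.r.t. the first argument, w
fast_loss = jax.jit(loss)   # XLA-compiled version of the same function

w = jnp.ones(3)
x = jnp.array([[1.0, 2.0, 3.0]])
y = jnp.array([2.0])
print(fast_loss(w, x, y), grad_loss(w, x, y))
```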
In JAX you don't have to worry about batching. The functions are implemented for a single sample, and jax.vmap can vectorize (parallelize) them across the batch dimension (or any other dimension if needed).
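As an illustrative sketch (the function below is hypothetical, not from this repository), a single-sample attention-score function can be lifted to a batched one with `jax.vmap`:

```python
import jax
import jax.numpy as jnp


def attention_scores(q, k):
    # q, k: [seq_len, d_k] for a single sample
    return jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)


# Vectorize over a leading batch axis: inputs become [batch, seq_len, d_k].
batched_scores = jax.vmap(attention_scores)

q = jnp.ones((8, 16, 64))  # batch of 8 samples
k = jnp.ones((8, 16, 64))
print(batched_scores(q, k).shape)  # (8, 16, 16)
```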