

Autoregressive Transformer Decoder in JAX from scratch

This implementation builds a transformer decoder from the ground up, without any higher-level frameworks like Flax. I have used labml for logging and experiment tracking.

I have implemented a simple Module class to serve as the base for the basic building blocks.
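As a rough sketch of the idea (the Module and Linear classes below are illustrative, not the repo's actual API), a base class can collect parameters from its attributes so that the forward computation stays a pure function of them:

```python
import jax
import jax.numpy as jnp

class Module:
    # Minimal base class: subclasses store parameters (and child modules)
    # as attributes, and params() walks them to build a flat dict.
    def params(self):
        out = {}
        for name, value in self.__dict__.items():
            if isinstance(value, jnp.ndarray):
                out[name] = value
            elif isinstance(value, Module):
                for k, v in value.params().items():
                    out[f'{name}.{k}'] = v
        return out

class Linear(Module):
    # A linear layer built on the base class (illustrative only)
    def __init__(self, key, n_in, n_out):
        self.weight = jax.random.normal(key, (n_in, n_out)) * (n_in ** -0.5)
        self.bias = jnp.zeros(n_out)

    def __call__(self, x):
        return x @ self.weight + self.bias

layer = Linear(jax.random.PRNGKey(0), 3, 4)
print(layer.params().keys())  # dict_keys(['weight', 'bias'])
```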

This was my first JAX project, and many of the implementations were adapted from the PyTorch implementations at nn.labml.ai.

JAX can optimize and differentiate pure Python functions. Pure functions are functions that take a set of arguments and return a result without modifying any state outside the function. JAX can also compile these functions with XLA, as well as vectorize them, to run efficiently.
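For example, here is a minimal sketch of differentiating and compiling a pure function (the loss function is just an illustration, not code from this repo):

```python
import jax
import jax.numpy as jnp

# A pure function: the output depends only on the inputs,
# and nothing outside the function is modified.
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.grad(loss)   # gradient w.r.t. the first argument, w
fast_loss = jax.jit(loss)  # compile with XLA

w = jnp.ones((3,))
x = jnp.ones((4, 3))
y = jnp.zeros((4,))
print(grad_fn(w, x, y))    # gradient of the loss w.r.t. w
print(fast_loss(w, x, y))  # same value as loss(w, x, y), compiled
```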

In JAX you don't have to worry about batches. The functions are implemented for a single sample, and jax.vmap can vectorize (parallelize) them across the batch dimension (or any other dimension if needed).
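A minimal sketch of this pattern (the scale function is only an illustration):

```python
import jax
import jax.numpy as jnp

# Written for a single sample: x has shape [d_model]
def scale(x, s):
    return x * s

# vmap maps over the leading (batch) axis of x, while s is shared
batched_scale = jax.vmap(scale, in_axes=(0, None))

x = jnp.ones((32, 512))              # a batch of 32 samples
print(batched_scale(x, 2.0).shape)   # (32, 512)
```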


View Run
Twitter thread