raymondzouu

2 issues by raymondzouu

To use this remat policy, set the `remat_policy` field to `custom`, then specify the labeled tensors that should be saved to HBM during the forward pass...
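The underlying mechanism can be sketched in plain JAX: intermediate tensors are labeled with `jax.ad_checkpoint.checkpoint_name`, and a policy such as `jax.checkpoint_policies.save_only_these_names` keeps only the named tensors as residuals while everything else is recomputed in the backward pass. This is a minimal sketch, assuming the `custom` policy builds on this JAX feature; the tensor names `mlp_in` and `mlp_act` and the toy layer are hypothetical, and the actual config keys for listing tensors are not shown here.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def layer(x, w1, w2):
    # Hypothetical toy MLP layer; intermediates are labeled so a policy can refer to them by name.
    h = checkpoint_name(jnp.dot(x, w1), "mlp_in")
    a = checkpoint_name(jax.nn.relu(h), "mlp_act")
    return jnp.dot(a, w2)


# Save only the tensor labeled "mlp_in" from the forward pass; "mlp_act" is recomputed.
policy = jax.checkpoint_policies.save_only_these_names("mlp_in")
remat_layer = jax.checkpoint(layer, policy=policy)


def make_loss(fn):
    # Scalar loss over the layer output, differentiated w.r.t. the weights.
    def loss(w1, w2, x):
        return jnp.sum(fn(x, w1, w2) ** 2)
    return loss


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (4, 8))
w1 = jax.random.normal(k2, (8, 16))
w2 = jax.random.normal(k3, (16, 8))

# Rematerialization changes memory use, not the math: gradients should match.
grads_remat = jax.grad(make_loss(remat_layer), argnums=(0, 1))(w1, w2, x)
grads_plain = jax.grad(make_loss(layer), argnums=(0, 1))(w1, w2, x)
```

Saving more names trades HBM for less recomputation; saving fewer does the opposite.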

Changes `l2norm` to use `jnp.sqrt` instead of `** 0.5`. This shows a speedup on small examples: https://screenshot.googleplex.com/A3GjjWQq5Dhes9b Colab notebook: http://shortn/_p369zYcGI2
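A minimal sketch of the change and a simple timing harness, assuming `l2norm` computes the Euclidean norm along the last axis (the real function's signature is not shown here). `x ** 0.5` typically lowers to a general `pow` primitive while `jnp.sqrt` lowers to a dedicated `sqrt`; inspecting both with `jax.make_jaxpr` confirms this on a given JAX version. No timing numbers are implied; results depend on hardware and shapes.

```python
import time

import jax
import jax.numpy as jnp


def l2norm_pow(x, axis=-1):
    # Previous form: square root written as a power of 0.5.
    return jnp.sum(x * x, axis=axis) ** 0.5


def l2norm_sqrt(x, axis=-1):
    # New form: dedicated square root.
    return jnp.sqrt(jnp.sum(x * x, axis=axis))


def bench(fn, x, iters=100):
    # Average wall-clock seconds per call of the jitted function.
    f = jax.jit(fn)
    f(x).block_until_ready()  # compile and warm up outside the timed loop
    start = time.perf_counter()
    for _ in range(iters):
        out = f(x)
    out.block_until_ready()  # wait for async dispatch to finish
    return (time.perf_counter() - start) / iters


x = jax.random.normal(jax.random.PRNGKey(0), (8, 128))
print("pow :", bench(l2norm_pow, x))
print("sqrt:", bench(l2norm_sqrt, x))
```

Both forms compute the same values up to floating-point rounding, so the swap should not change model outputs.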

pull ready