jmp
JMP is a Mixed Precision library for JAX.
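Mixed precision here means keeping parameters in float32 while running the computation in a 16-bit dtype and casting results back up. The following is a minimal pure-JAX sketch of that idea under stated assumptions. It is not jmp's API: `cast_tree`, `linear`, and the three dtype constants are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical three-dtype policy (not jmp's API): float32 master parameters,
# bfloat16 computation, float32 outputs.
PARAM_DTYPE = jnp.float32
COMPUTE_DTYPE = jnp.bfloat16
OUTPUT_DTYPE = jnp.float32


def cast_tree(tree, dtype):
    """Cast every floating-point leaf of a pytree to `dtype`."""
    def cast(x):
        return x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x
    return jax.tree_util.tree_map(cast, tree)


def linear(params, x):
    # Cast the float32 parameters and inputs down for the computation.
    p = cast_tree(params, COMPUTE_DTYPE)
    x = cast_tree(x, COMPUTE_DTYPE)
    y = x @ p["w"] + p["b"]
    # Cast the result back up before it is used elsewhere.
    return cast_tree(y, OUTPUT_DTYPE)


params = {"w": jnp.ones((3, 2), PARAM_DTYPE), "b": jnp.zeros((2,), PARAM_DTYPE)}
x = jnp.ones((4, 3), PARAM_DTYPE)
y = linear(params, x)
print(y.dtype, y.shape)  # float32 (4, 2)
```

The stored parameters never leave float32; only the temporary copies used inside `linear` are 16-bit.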
Following the provided example for the DynamicLossScale causes errors if run directly:
```python
import jax
import jax.numpy as jnp
import jmp

dyn = jmp.DynamicLossScale(jnp.float16(2**15))
g = jnp.ones(5, jnp.float16)
finite =...
```
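The report pairs DynamicLossScale with a finiteness check on gradients. As a rough, self-contained illustration of what dynamic loss scaling does, here is a pure-JAX sketch that does not use jmp at all. `all_finite` and `adjust` are hypothetical helpers, and the `period`, `factor`, and `min_scale` defaults are assumptions, not values taken from jmp.

```python
import jax
import jax.numpy as jnp


def all_finite(tree):
    """True if every leaf of the pytree holds only finite values."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.array([jnp.all(jnp.isfinite(x)) for x in leaves]))


def adjust(scale, counter, grads_finite, period=2000, factor=2.0, min_scale=1.0):
    """Hypothetical dynamic loss-scale update (defaults are assumptions).

    Finite grads: count the step; after `period` finite steps, grow the scale.
    Non-finite grads: shrink the scale (never below min_scale), reset counter.
    """
    period_reached = counter + 1 >= period
    grown = jnp.where(period_reached, scale * factor, scale)
    next_counter = jnp.where(period_reached, 0, counter + 1)
    new_scale = jnp.where(grads_finite, grown, jnp.maximum(scale / factor, min_scale))
    new_counter = jnp.where(grads_finite, next_counter, 0)
    return new_scale, new_counter


scale = jnp.float32(2**15)
counter = jnp.int32(0)

# Gradients of a loss multiplied by `scale` must be divided by it again.
scaled_grads = {"w": jnp.ones(5, jnp.float16) * 2**15}
grads = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32) / scale, scaled_grads)
scale1, counter1 = adjust(scale, counter, all_finite(grads))  # scale kept, counter -> 1

# An overflowing step halves the scale and resets the counter.
bad_grads = {"w": jnp.array([jnp.inf, 1.0], jnp.float32)}
scale2, counter2 = adjust(scale1, counter1, all_finite(bad_grads))  # 16384.0, 0
```

Unscaling in float32 before the finiteness check avoids a second float16 overflow during the division.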
DynamicLossScale.tree_flatten does not include DynamicLossScale.min_loss_scale in its outputs, so DynamicLossScale.tree_unflatten falls back to the field's default factory when rebuilding the object. This discards the original value and can also change its dtype.
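A minimal sketch of this failure mode, assuming a dataclass-based pytree shaped like the one the report describes. `_Fields`, `BuggyState`, and `FixedState` are hypothetical names, not jmp classes, and the float32 default factory for `min_loss_scale` is an assumption chosen to make the dtype change visible.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass(frozen=True)
class _Fields:
    """Hypothetical stand-in for a dataclass-based loss-scale pytree."""
    loss_scale: jnp.ndarray
    counter: jnp.ndarray = dataclasses.field(
        default_factory=lambda: np.zeros([], np.int32))
    min_loss_scale: jnp.ndarray = dataclasses.field(
        default_factory=lambda: np.ones([], np.float32))


@jax.tree_util.register_pytree_node_class
class BuggyState(_Fields):
    def tree_flatten(self):
        # Bug: min_loss_scale is neither a child nor aux data...
        return (self.loss_scale, self.counter), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # ...so the dataclass default factory silently fills it back in.
        return cls(*children)


@jax.tree_util.register_pytree_node_class
class FixedState(_Fields):
    def tree_flatten(self):
        # Fix: carry every field through flatten/unflatten.
        return (self.loss_scale, self.counter, self.min_loss_scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


identity = lambda x: x
args = (jnp.float16(2**15), jnp.int32(0), jnp.float16(4.0))

out_buggy = jax.tree_util.tree_map(identity, BuggyState(*args))
out_fixed = jax.tree_util.tree_map(identity, FixedState(*args))
# Buggy round trip: value reset to 1.0 and dtype switched to float32.
print(out_buggy.min_loss_scale, out_buggy.min_loss_scale.dtype)
# Fixed round trip: original float16 value 4.0 preserved.
print(out_fixed.min_loss_scale, out_fixed.min_loss_scale.dtype)
```

Any pytree transformation (`tree_map`, `jit`, `grad`, optimizer updates) goes through the same flatten/unflatten round trip, which is why the dropped field resets everywhere rather than only in an explicit `tree_unflatten` call.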
In the tutorial, you mention that bf16 should be used on TPU. Does bf16 also work on GPU?
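JAX accepts bfloat16 arrays on its CPU, GPU, and TPU backends, so bf16 code runs on GPU; whether it is actually faster than float32 there depends on the GPU having native bf16 support (for example, NVIDIA Ampere and newer). The practical difference from float16 is range: bfloat16 keeps float32's 8-bit exponent, so it rarely needs loss scaling, at the cost of fewer mantissa bits. The sketch below only inspects dtype properties on whatever backend is available; it does not benchmark anything.

```python
import jax
import jax.numpy as jnp

# Compare the range and precision of the two 16-bit float formats.
for dt in (jnp.float16, jnp.bfloat16):
    info = jnp.finfo(dt)
    print(jnp.dtype(dt).name, "max:", float(info.max), "eps:", float(info.eps))

# bfloat16 arrays are created and used the same way on any JAX backend.
print("backend:", jax.default_backend())
x = jnp.ones((4, 4), jnp.bfloat16)
y = x @ x
print(y.dtype)  # bfloat16

# 1e5 exceeds float16's maximum (65504) but is well within bfloat16's range.
big = jnp.float32(1e5)
print(jnp.isfinite(big.astype(jnp.float16)), jnp.isfinite(big.astype(jnp.bfloat16)))
```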