jmp

JMP is a Mixed Precision library for JAX.

Results 3 jmp issues

Following the provided example for the DynamicLossScale causes errors if run directly.

```python
import jax
import jax.numpy as jnp
import jmp

dyn = jmp.DynamicLossScale(jnp.float16(2**15))
g = jnp.ones(5, jnp.float16)
finite =...
```

DynamicLossScale.min_loss_scale isn't included in the outputs of DynamicLossScale.tree_flatten, so calling DynamicLossScale.tree_unflatten triggers the field's default factory. The original value is lost, and the dtype may change as well.
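The failure mode described above can be reproduced in isolation. The sketch below does not use jmp's code: `LossyScale` and `FixedScale` are hypothetical stand-in classes, and the float32 default of 1 is an assumption chosen for illustration. It only shows how a field left out of `tree_flatten` is rebuilt from its default on `tree_unflatten`, and how returning the field as a child avoids that.

```python
import dataclasses

import jax
import jax.numpy as jnp


def _default_min():
    # Assumed default for illustration: a float32 scalar equal to 1.
    return jnp.ones([], jnp.float32)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class LossyScale:
    """Hypothetical stand-in (not jmp's class) that drops a field when flattened."""
    loss_scale: jax.Array
    min_loss_scale: jax.Array = dataclasses.field(default_factory=_default_min)

    def tree_flatten(self):
        # Bug pattern: min_loss_scale is neither a child nor auxiliary data.
        return (self.loss_scale,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        (loss_scale,) = children
        # min_loss_scale is not supplied, so the default factory runs here.
        return cls(loss_scale)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class FixedScale:
    """Hypothetical stand-in that round-trips every field."""
    loss_scale: jax.Array
    min_loss_scale: jax.Array = dataclasses.field(default_factory=_default_min)

    def tree_flatten(self):
        # Both fields are returned as children, so both survive the round trip.
        return (self.loss_scale, self.min_loss_scale), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def roundtrip(tree):
    # Flatten to leaves and rebuild, as JAX transformations do internally.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    return jax.tree_util.tree_unflatten(treedef, leaves)


fields = dict(loss_scale=jnp.float16(2**15), min_loss_scale=jnp.float16(4))
lossy = roundtrip(LossyScale(**fields))
fixed = roundtrip(FixedScale(**fields))

print("lossy:", lossy.min_loss_scale, lossy.min_loss_scale.dtype)
print("fixed:", fixed.min_loss_scale, fixed.min_loss_scale.dtype)
```

In the lossy variant the float16 value passed in is replaced by the float32 default after the round trip. The fixed variant keeps both the value and the dtype. Whether the field is best carried as a child or as auxiliary data depends on whether it should be traced, which this sketch does not address.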

In the tutorial, you mentioned we should use bf16 for TPU; but does bf16 also work for GPU?