Incorrect dtype promotion in DynamicLossScale
Following the provided example for DynamicLossScale causes an error when run directly:
import jax
import jax.numpy as jnp
import jmp
# Build the dynamic loss scale in float16, as in the provided example.
dyn = jmp.DynamicLossScale(jnp.float16(2**15))
# Simulate a step whose gradients are not all finite, which should shrink the scale.
g = jnp.ones(5, jnp.float16)
finite = jmp.all_finite(g)
dyn.adjust(~finite)
>> TypeError: lax.select requires arguments to have the same dtypes, got float32, float16. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
This is of course fixed by doing jmp.DynamicLossScale(jnp.float32(2**15)), but doesn't this defeat the purpose of this object?
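To make the concern concrete (a hedged sketch, assuming DynamicLossScale.scale does a plain multiply by self.loss_scale, which is exactly what the change below addresses): once the scale is stored in float32, anything it multiplies gets promoted to float32 under JAX's type-promotion rules, so the half-precision compute the object is meant to preserve is silently lost.

import jax.numpy as jnp
import jmp

# Workaround: keep the loss scale itself in float32 so adjust() works.
dyn = jmp.DynamicLossScale(jnp.float32(2**15))

loss = jnp.float16(1.0)
scaled = dyn.scale(loss)
# float16 * float32 promotes to float32, so (assuming scale() multiplies each
# leaf by self.loss_scale without casting back) the scaled loss is upcast:
print(scaled.dtype)  # float32, not the float16 we wanted to stay in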
What is needed is to construct the scale as jmp.DynamicLossScale(jnp.float32(2**15)) and to change loss_scale.py:132 to
return jax.tree_util.tree_map(lambda x: (x * self.loss_scale).astype(x.dtype), tree)
This way gradients are still computed in float16, and loss_scale.loss_scale won't overflow after the first 2000 steps (if it stays in float16, jmp will increase it to 2**16, which is outside the representable range of float16).
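A minimal sketch of what that change amounts to (a hypothetical subclass; it assumes the method at loss_scale.py:132 is DynamicLossScale.scale and that it currently multiplies each leaf by self.loss_scale without a cast):

import jax
import jax.numpy as jnp
import jmp

class PatchedDynamicLossScale(jmp.DynamicLossScale):
    # Hypothetical subclass carrying the proposed cast-back-to-leaf-dtype fix.
    def scale(self, tree):
        # Multiply in the promoted dtype, then cast each leaf back to its own
        # dtype so float16 losses/gradients stay float16 while the scale value
        # itself lives in float32 and can safely grow past 2**15.
        return jax.tree_util.tree_map(
            lambda x: (x * self.loss_scale).astype(x.dtype), tree)

dyn = PatchedDynamicLossScale(jnp.float32(2**15))
scaled = dyn.scale(jnp.float16(1.0))      # stays float16
grads = jnp.ones(5, jnp.float16)
dyn = dyn.adjust(jmp.all_finite(grads))   # no lax.select dtype mismatch

With the scale kept in float32, doubling it every period (2000 steps by default) is safe well past 65504, which is where a float16 scale would overflow.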
What really puzzles me is that this is the only JAX mixed-precision package that comes up in searches, and it is evidently not just unmaintained but has been broken for months with no one caring. Which raises two possibilities:
- Does everyone use jax to train their models strictly in float32 or bf16?
- Does no one use jax any more?
I'll try that with a fork of this repo when I have time. Thanks for the suggestion!
I think JAX is growing in popularity though, haha ;p. These open-source projects just might not be a DeepMind priority.