
Incorrect dtype promotion in DynamicLossScale

joeryjoery opened this issue 1 year ago · 2 comments

Following the provided example for DynamicLossScale causes an error when run directly:

import jax
import jax.numpy as jnp
import jmp

dyn = jmp.DynamicLossScale(jnp.float16(2**15))

g = jnp.ones(5, jnp.float16)
finite = jmp.all_finite(g)
dyn.adjust(~finite)
>> TypeError: lax.select requires arguments to have the same dtypes, got float32, float16. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).

This is of course fixed by doing jmp.DynamicLossScale(jnp.float32(2**15)), but doesn't this defeat the purpose of this object?
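For reference, a minimal sketch of that workaround (the adjust call goes through once the loss scale is float32, but scaling a float16 tree then promotes it to float32 under JAX's default type-promotion rules, which is the "defeats the purpose" part):

import jax.numpy as jnp
import jmp

# Workaround: keep the loss scale itself in float32.
dyn = jmp.DynamicLossScale(jnp.float32(2**15))

g = jnp.ones(5, jnp.float16)
finite = jmp.all_finite(g)
dyn = dyn.adjust(finite)   # no dtype error with a float32 loss scale

# But multiplying float16 values by the float32 scale promotes them.
print(dyn.scale(g).dtype)  # float32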

joeryjoery avatar Jan 26 '24 11:01 joeryjoery

What is needed is to construct DynamicLossScale as jmp.DynamicLossScale(jnp.float32(2**15)) and to change loss_scale.py:132 to return jax.tree_util.tree_map(lambda x: (x * self.loss_scale).astype(x.dtype), tree). This way gradients are still computed in float16, and loss_scale.loss_scale won't overflow after the first 2000 steps (if it is kept in float16, jmp will increase it to 2**16, which is outside the representable range of float16).
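Concretely, here is a rough sketch of that change tried as a subclass instead of patching the installed package (CastingDynamicLossScale is just an illustrative name, and this assumes loss_scale.py:132 is the scale method; the cast back to x.dtype keeps scaled values in their original precision):

import jax
import jax.numpy as jnp
import jmp

class CastingDynamicLossScale(jmp.DynamicLossScale):
    """DynamicLossScale whose scale() casts back to the input dtype."""

    def scale(self, tree):
        # Multiply by the float32 loss scale, then cast each leaf back to
        # its original dtype so float16 values stay float16.
        return jax.tree_util.tree_map(
            lambda x: (x * self.loss_scale).astype(x.dtype), tree)

dyn = CastingDynamicLossScale(jnp.float32(2**15))
g = jnp.ones(5, jnp.float16)
print(dyn.scale(g).dtype)  # float16

One caveat: depending on the jmp version, adjust() may construct a plain DynamicLossScale rather than the subclass, so a fork with the one-line change is probably the cleaner route anyway.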

What really puzzles me is that this is the only JAX mixed-precision package that comes up in searches, and it is evidently not just dead: it has been broken for months and no one seems to care. This raises two possibilities:

  1. Does everyone use jax to train their models strictly in float32 or bf16?
  2. Does no one use jax any more?

ekuznetsov139 avatar May 19 '24 02:05 ekuznetsov139

I'll try that with a fork of this repo when I have time; thanks for the suggestion!

I think JAX is growing in popularity, haha ;p. These open-source projects might just not be a DeepMind priority, though.

joeryjoery avatar May 20 '24 15:05 joeryjoery