Adafactor + MultiSteps with a bfloat16 model doesn't work
If I use Adafactor with MultiSteps on a bfloat16 model, I get this strange error (note that the error is extremely long, so I truncated it to fit in the issue; the model is T5-small):
Traceback (most recent call last):
File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/main.py", line 135, in <module>
train.unroll(metaconfig)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/micro_config.py", line 39, in new_unroll
result = unroll(self, metaconfig)
File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/train_loop.py", line 372, in unroll
logs, params, opt_state = p_step_fn(params, opt_state, new_rng, items)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 352, in wrapped
args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 330, in infer_params
jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 490, in _pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/train_loop.py", line 337, in t5_step_fn
updates, opt_state = optim.update(grads, opt_state, params)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/optax/_src/wrappers.py", line 413, in update
new_updates, new_state = jax.lax.cond(
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 252, in cond
return _cond_with_per_branch_args(*ba.args)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 273, in _cond_with_per_branch_args
return _cond(pred,
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 223, in _cond
_check_tree_and_avals("true_fun and false_fun output",
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/common.py", line 105, in _check_tree_and_avals
raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
(FrozenDict({
decoder: {
block: {
0: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
relative_attention_bias: {
embedding: 'DIFFERENT ShapedArray(bfloat16[32,8]) vs. ShapedArray(float32[32,8])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
1: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
2: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
3: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
4: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
5: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
},
final_layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
encoder: {
block: {
0: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
relative_attention_bias: {
embedding: 'DIFFERENT ShapedArray(bfloat16[32,8]) vs. ShapedArray(float32[32,8])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
1: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
2: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
3: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
4: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
5: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
},
final_layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
shared: {
embedding: 'DIFFERENT ShapedArray(bfloat16[32128,512]) vs. ShapedArray(float32[32128,512])',
},
}), MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row=FrozenDict({
decoder: {
block: {
0: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'ShapedArray(float32[512])',
},
o: {
kernel: 'ShapedArray(float32[512])',
},
q: {
kernel: 'ShapedArray(float32[512])',
},
relative_attention_bias: {
embedding: 'ShapedArray(float32[1])',
},
v: {
kernel: 'ShapedArray(float32[512])',
},
},
layer_norm: {
weight: 'ShapedArray(float32[1])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'ShapedArray(float32[512])',
},
o: {
kernel: 'ShapedArray(float32[512])',
},
q: {
kernel: 'ShapedArray(float32[512])',
},
v: {
kernel: 'ShapedArray(float32[512])',
},
The error points to this line of optax.MultiSteps. It is essentially saying that mid_step's first return value is float32 while final_step's is bfloat16. If I force-cast mid_step's return to bfloat16, the error goes away. Looking at the code, I'm not exactly sure why this happens; it looks like it should handle the dtypes correctly. So if anyone has an explanation or a non-hacky fix, that would be appreciated.
Note that the optimizer is being called inside a pjit on a TPU v3, and I don't get this error with AdamW + MultiSteps + bfloat16.
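For illustration, here is a minimal, hypothetical sketch of the constraint at play (not optax's actual code): jax.lax.cond traces both branches and requires their outputs to have identical shapes and dtypes, so if one branch inside MultiSteps yields float32 updates while the other yields bfloat16, it fails with exactly this TypeError.

import jax
import jax.numpy as jnp

grads = jnp.ones((2, 2), dtype=jnp.bfloat16)

def emit_f32_updates(g):
    # stands in for a branch whose inner optimizer produces float32 updates
    return g.astype(jnp.float32)

def emit_bf16_updates(g):
    # stands in for a branch that passes through / accumulates bfloat16
    return jnp.zeros_like(g)

# Raises: TypeError: true_fun and false_fun output must have identical types
jax.lax.cond(jnp.array(True), emit_f32_updates, emit_bf16_updates, grads)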
Interesting. Based on your description, this would only happen if the dtype inference in line 383 results in the wrong type, so I could try looking into whether the dtype returned from optax.scale_by_factored_rms is correct. Do you have a minimal example of the error I could try this with?
Thanks a lot for raising this!
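A minimal probe along those lines might look like the following (a sketch, not a confirmed diagnosis; it simply prints the dtype that scale_by_factored_rms emits for bfloat16 inputs):

import jax.numpy as jnp
import optax

# Probe the factored-RMS transform with bfloat16 params and grads.
tx = optax.scale_by_factored_rms()
params = jnp.zeros((8, 8), dtype=jnp.bfloat16)
grads = jnp.ones((8, 8), dtype=jnp.bfloat16)
state = tx.init(params)
updates, _ = tx.update(grads, state, params)
print(updates.dtype)  # float32 here would mean the transform does not preserve bfloat16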
@Sea-Snell, as @mkunesch mentioned, it would be helpful to have a minimal example we could try this with.
@Sea-Snell I think this issue is not fixed and should be reopened.
Repro:
import os; os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
import optax

@jax.jit
@jax.value_and_grad
def f(params, x, labels):
    logits = params @ x
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return loss.mean()

params = jnp.zeros((5, 18), dtype=jnp.bfloat16)
x = jnp.zeros((18, 4), dtype=jnp.bfloat16)
labels = jnp.zeros((5,), dtype=jnp.uint16)
value, grad = f(params, x, labels)

lr = 0.00005
n_accumulation_steps = 4
optimizer = optax.adafactor(learning_rate=lr)
optimizer = optax.MultiSteps(optimizer, n_accumulation_steps)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grad, opt_state, params)
print(updates)
Error:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "/home/ayaka/llama-2-jax/1.py", line 24, in <module>
updates, opt_state = optimizer.update(grad, opt_state, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/optax/_src/wrappers.py", line 423, in update
new_updates, new_state = jax.lax.cond(
^^^^^^^^^^^^^
File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 286, in cond
return _cond_with_per_branch_args(*ba.args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 307, in _cond_with_per_branch_args
return _cond(pred,
^^^^^^^^^^^
File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/conditionals.py", line 251, in _cond
_check_tree_and_avals("true_fun and false_fun output",
File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py", line 202, in _check_tree_and_avals
raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(bfloat16[5,18]) vs. ShapedArray(float32[5,18])', MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row='ShapedArray(float32[1])', v_col='ShapedArray(float32[1])', v='ShapedArray(float32[5,18])'), EmptyState(), EmptyState(), EmptyState(), EmptyState()), acc_grads='ShapedArray(bfloat16[5,18])', skip_state=())).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ayaka/llama-2-jax/1.py", line 24, in <module>
updates, opt_state = optimizer.update(grad, opt_state, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ayaka/llama-2-jax/venv/lib/python3.11/site-packages/optax/_src/wrappers.py", line 423, in update
new_updates, new_state = jax.lax.cond(
^^^^^^^^^^^^^
TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(bfloat16[5,18]) vs. ShapedArray(float32[5,18])', MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row='ShapedArray(float32[1])', v_col='ShapedArray(float32[1])', v='ShapedArray(float32[5,18])'), EmptyState(), EmptyState(), EmptyState(), EmptyState()), acc_grads='ShapedArray(bfloat16[5,18])', skip_state=())).
However, either of the following modifications works (sketched below):
- Change the optimiser from optax.adafactor to optax.adamw
- Remove optax.MultiSteps
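A sketch of both working variants, reusing params, grad, lr and n_accumulation_steps from the repro above (illustrative only):

# Variant 1: swap adafactor for adamw; the MultiSteps wrapper then accepts the bf16 grads.
optimizer = optax.MultiSteps(optax.adamw(learning_rate=lr), n_accumulation_steps)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grad, opt_state, params)

# Variant 2: keep adafactor but drop the MultiSteps wrapper.
optimizer = optax.adafactor(learning_rate=lr)
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grad, opt_state, params)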
I had the same problem. I found that it happens because adafactor returns float32 updates even though the params and gradients are bfloat16, while MultiSteps expects them to have the same dtype when applying jax.lax.cond. This happens because scale_by_factored_rms inside adafactor does not preserve the dtype of the updates propagating through it; a lot of the variables in its internal state are float32.
One quick fix is to add an explicit conversion, update.astype(grad.dtype), to this line. If that sounds good, I'd be glad to submit a PR.
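Until such a fix lands, one user-side workaround is to wrap the inner transform so its updates are cast back to the gradients' dtype before MultiSteps sees them. This is a hedged sketch; cast_updates_to_grad_dtype is a made-up helper, not an optax API:

import jax
import optax

def cast_updates_to_grad_dtype(inner: optax.GradientTransformation) -> optax.GradientTransformation:
    # Hypothetical wrapper: cast the inner transform's updates to the dtype of the incoming grads.
    def init_fn(params):
        return inner.init(params)

    def update_fn(updates, state, params=None):
        new_updates, new_state = inner.update(updates, state, params)
        new_updates = jax.tree_util.tree_map(
            lambda u, g: u.astype(g.dtype), new_updates, updates)
        return new_updates, new_state

    return optax.GradientTransformation(init_fn, update_fn)

# Usage (illustrative):
# optimizer = optax.MultiSteps(cast_updates_to_grad_dtype(optax.adafactor(learning_rate=1e-4)), 4)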
As @mk-0 mentioned, this is due to the jax.lax.cond inside MultiSteps. Since the error also occurs when using flax.training.dynamic_scale, I think a fix inside MultiSteps would be better. The error occurs basically every time some values of the gradient have a different dtype than the corresponding parameters, i.e. often when some kind of scaling is applied that requires casting bfloat16 to float32. I'll open a PR with a possible fix.