optax
optax copied to clipboard
Adafactor + MultiStep with a bfloat16 model doesn't work
If I use Adafactor with MultiStep on a bfloat16 model, I get this strange error (note: the full error is extremely long, so I truncated it to fit in the issue; the model is T5-small):
Traceback (most recent call last):
File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/main.py", line 135, in <module>
train.unroll(metaconfig)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/micro_config.py", line 39, in new_unroll
result = unroll(self, metaconfig)
File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/train_loop.py", line 372, in unroll
logs, params, opt_state = p_step_fn(params, opt_state, new_rng, items)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 352, in wrapped
args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 330, in infer_params
jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 490, in _pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/train_loop.py", line 337, in t5_step_fn
updates, opt_state = optim.update(grads, opt_state, params)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/optax/_src/wrappers.py", line 413, in update
new_updates, new_state = jax.lax.cond(
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 252, in cond
return _cond_with_per_branch_args(*ba.args)
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 273, in _cond_with_per_branch_args
return _cond(pred,
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 223, in _cond
_check_tree_and_avals("true_fun and false_fun output",
File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/common.py", line 105, in _check_tree_and_avals
raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
(FrozenDict({
decoder: {
block: {
0: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
relative_attention_bias: {
embedding: 'DIFFERENT ShapedArray(bfloat16[32,8]) vs. ShapedArray(float32[32,8])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
1: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
2: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
3: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
4: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
5: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
2: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
},
final_layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
encoder: {
block: {
0: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
relative_attention_bias: {
embedding: 'DIFFERENT ShapedArray(bfloat16[32,8]) vs. ShapedArray(float32[32,8])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
1: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
2: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
3: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
4: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
5: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
o: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
q: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
v: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
1: {
DenseReluDense: {
wi: {
kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
},
wo: {
kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
},
},
layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
},
},
},
final_layer_norm: {
weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
},
},
shared: {
embedding: 'DIFFERENT ShapedArray(bfloat16[32128,512]) vs. ShapedArray(float32[32128,512])',
},
}), MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row=FrozenDict({
decoder: {
block: {
0: {
layer: {
0: {
SelfAttention: {
k: {
kernel: 'ShapedArray(float32[512])',
},
o: {
kernel: 'ShapedArray(float32[512])',
},
q: {
kernel: 'ShapedArray(float32[512])',
},
relative_attention_bias: {
embedding: 'ShapedArray(float32[1])',
},
v: {
kernel: 'ShapedArray(float32[512])',
},
},
layer_norm: {
weight: 'ShapedArray(float32[1])',
},
},
1: {
EncDecAttention: {
k: {
kernel: 'ShapedArray(float32[512])',
},
o: {
kernel: 'ShapedArray(float32[512])',
},
q: {
kernel: 'ShapedArray(float32[512])',
},
v: {
kernel: 'ShapedArray(float32[512])',
},
The error points to this line of optax.MultiSteps. It is essentially saying that mid_step's first return value has type fp32 while final_step's has type bfloat16. If I force-cast mid_step's return to bfloat16, the error goes away. Looking at the code, I'm not exactly sure why this would happen — it looks like it should handle the dtypes correctly — so an explanation or a non-hacky fix would be appreciated.
Note that the optimizer is being called inside a pjit on a TPUv3, and I don't get this error with AdamW + MultiStep + bfloat16.