
Adafactor + MultiStep with bfloat16 model doesn't work

Open Sea-Snell opened this issue 1 year ago • 5 comments

If I use Adafactor with MultiSteps on a bfloat16 model, I get this strange error. The model is T5-small; the full error is extremely long, so I truncated it to fit in the issue:

Traceback (most recent call last):
  File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/main.py", line 135, in <module>
    train.unroll(metaconfig)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/micro_config.py", line 39, in new_unroll
    result = unroll(self, metaconfig)
  File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/train_loop.py", line 372, in unroll
    logs, params, opt_state = p_step_fn(params, opt_state, new_rng, items)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 352, in wrapped
    args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 330, in infer_params
    jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/experimental/pjit.py", line 490, in _pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
  File "/home/charliesnell/jax_v_pytorch/large_lm_finetune/flax/train_loop.py", line 337, in t5_step_fn
    updates, opt_state = optim.update(grads, opt_state, params)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/optax/_src/wrappers.py", line 413, in update
    new_updates, new_state = jax.lax.cond(
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 252, in cond
    return _cond_with_per_branch_args(*ba.args)
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 273, in _cond_with_per_branch_args
    return _cond(pred,
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/conditionals.py", line 223, in _cond
    _check_tree_and_avals("true_fun and false_fun output",
  File "/home/charliesnell/miniconda3/envs/jax_v_torch/lib/python3.9/site-packages/jax/_src/lax/control_flow/common.py", line 105, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
(FrozenDict({
    decoder: {
        block: {
            0: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            relative_attention_bias: {
                                embedding: 'DIFFERENT ShapedArray(bfloat16[32,8]) vs. ShapedArray(float32[32,8])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            1: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            2: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            3: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            4: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            5: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    2: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
        },
        final_layer_norm: {
            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
        },
    },
    encoder: {
        block: {
            0: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            relative_attention_bias: {
                                embedding: 'DIFFERENT ShapedArray(bfloat16[32,8]) vs. ShapedArray(float32[32,8])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            1: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            2: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            3: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            4: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
            5: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            o: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            q: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                            v: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,512]) vs. ShapedArray(float32[512,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                    1: {
                        DenseReluDense: {
                            wi: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[512,2048]) vs. ShapedArray(float32[512,2048])',
                            },
                            wo: {
                                kernel: 'DIFFERENT ShapedArray(bfloat16[2048,512]) vs. ShapedArray(float32[2048,512])',
                            },
                        },
                        layer_norm: {
                            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
                        },
                    },
                },
            },
        },
        final_layer_norm: {
            weight: 'DIFFERENT ShapedArray(bfloat16[512]) vs. ShapedArray(float32[512])',
        },
    },
    shared: {
        embedding: 'DIFFERENT ShapedArray(bfloat16[32128,512]) vs. ShapedArray(float32[32128,512])',
    },
}), MultiStepsState(mini_step='ShapedArray(int32[])', gradient_step='ShapedArray(int32[])', inner_opt_state=(FactoredState(count='ShapedArray(int32[])', v_row=FrozenDict({
    decoder: {
        block: {
            0: {
                layer: {
                    0: {
                        SelfAttention: {
                            k: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            o: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            q: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            relative_attention_bias: {
                                embedding: 'ShapedArray(float32[1])',
                            },
                            v: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                        },
                        layer_norm: {
                            weight: 'ShapedArray(float32[1])',
                        },
                    },
                    1: {
                        EncDecAttention: {
                            k: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            o: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            q: {
                                kernel: 'ShapedArray(float32[512])',
                            },
                            v: {
                                kernel: 'ShapedArray(float32[512])',
                            },

The error points to this line of optax.MultiSteps: the first return value of `mid_step` has dtype float32 while `final_step`'s has dtype bfloat16, so the two branches of `jax.lax.cond` disagree. If I force-cast `mid_step`'s return to bfloat16, the error goes away. Looking at the code, I'm not sure why this happens; it looks like it should handle the dtypes correctly. An explanation or a non-hacky fix would be appreciated.

Note that the optimizer is being called inside a pjit on a TPU v3, and I don't get this error with AdamW + MultiSteps + bfloat16.

Sea-Snell · Jul 21 '22 14:07