Training Error: TypeError: true_fun and false_fun output must have identical types
Hello,
I'm attempting to fine tune dalle-mega but hit this error while trying to stand things up:
Traceback (most recent call last):
File "/home/netruk44/ml/workspace/repos/dalle-mini/tools/train/train.py", line 1742, in <module>
main()
File "/home/netruk44/ml/workspace/repos/dalle-mini/tools/train/train.py", line 1702, in main
state, train_metrics = p_train_step(state, batch, train_time)
File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/experimental/pjit.py", line 367, in wrapped
args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/experimental/pjit.py", line 344, in infer_params
jaxpr, normalized_out_shardings_flat = _pjit_jaxpr(
File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/experimental/pjit.py", line 568, in _pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
File "/home/netruk44/ml/workspace/repos/dalle-mini/tools/train/train.py", line 1299, in train_step
gradients_norm = maybe_fn(
File "/home/netruk44/ml/workspace/repos/dalle-mini/tools/train/train.py", line 1283, in maybe_fn
return jax.lax.cond(
File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/_src/lax/control_flow/conditionals.py", line 254, in cond
return _cond(*args, **kwargs)
File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/_src/lax/control_flow/conditionals.py", line 223, in _cond
_check_tree_and_avals("true_fun and false_fun output",
File "/home/netruk44/anaconda3/envs/dalle-mini/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 105, in _check_tree_and_avals
raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
FrozenDict({
lm_head: {
kernel: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
},
model: {
decoder: {
embed_positions: {
embedding: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
},
embed_tokens: {
embedding: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
},
final_ln: {
bias: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
},
layernorm_embedding: {
bias: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
scale: 'DIFFERENT ShapedArray(float16[]) vs. ShapedArray(float32[])',
},
},
},
}).
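If I understand the error correctly, jax.lax.cond traces both branches and requires their outputs to have identical pytree structure, shapes, and dtypes. A minimal sketch of the mismatch (not the actual train.py code, just an illustration assuming float16 parameters, as in the fp16 checkpoint):

import jax
import jax.numpy as jnp

def maybe_sum(compute, x):
    # Both branches must return outputs with identical dtypes; here the
    # true branch follows x's dtype (float16) while the false branch is
    # hard-coded to float32, which reproduces the TypeError above.
    return jax.lax.cond(
        compute,
        lambda x: jnp.sum(x),        # dtype follows x (float16)
        lambda x: jnp.float32(0),    # always float32 -> mismatch
        x,
    )

# maybe_sum(True, jnp.ones(4, dtype=jnp.float16))  # raises the TypeError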
I believe that I have worked around the error by changing this line: https://github.com/borisdayma/dalle-mini/blob/main/tools/train/train.py#L1294
# Old
zeros_norm = jax.tree_util.tree_map(lambda _: jnp.float32(0), params)

# Fixed
zeros_norm = jax.tree_util.tree_map(lambda _: jnp.float16(0), params)
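If hard-coding float16 is too brittle (it would presumably break fp32 checkpoints the same way), an alternative sketch would be to derive the zero's dtype from each parameter leaf, assuming the per-parameter norms computed in the other branch keep the parameter's dtype:

# Untested alternative: match each leaf's dtype instead of hard-coding one
zeros_norm = jax.tree_util.tree_map(
    lambda p: jnp.zeros((), dtype=p.dtype), params
)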
However, I'm not very familiar with this code or what it's doing, so I don't know if this is the 'correct' solution, but I thought I would at least mention the problem to see if it's anything that might need to be addressed. It's possible I'm configuring something wrong somewhere, so this might just be a personal problem lol.
It looks like the error is likely coming from the choice of checkpoint I passed into train.py via model_name_or_path. Starting fine-tuning from checkpoint dalle-mini/dalle-mini/mega-1-fp16:latest, I get the error mentioned above, but if I fine-tune from checkpoint dalle-mini/dalle-mini/mega-1:latest, it works as-is without any modifications.
I'll leave this issue open in case you'd like to do something with it, but I'd also be fine with just closing it 😄
Oh interesting, thanks for reporting it