Gradient accumulation trick and Activation Checkpointing feature
Feature request
- Adds gradient accumulation trick to https://github.com/huggingface/transformers/blob/main/examples/flax/summarization/run_summarization_flax.py
- Adds Activation Checkpointing feature
Motivation
To address GPU memory issues and to speed up training.
In the 'Your contribution' section below, might I ask whether the extra if-else block makes sense, or whether we even need optax.apply_every() for gradient accumulation?
Your contribution
The following JAX code is modified from the original Hugging Face version:
batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps

# add gradient accumulation
if training_args.gradient_accumulation_steps > 1:
    optimizer = optax.chain(
        optax.apply_every(batch_size_per_update), optimizer
    )

# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
if len(accumulated_gradients) < training_args.gradient_accumulation_steps:
    accumulated_gradients.append(grad)
    new_state = state
else:
    grad = jax.tree_multimap(lambda *x: jnp.sum(jnp.stack(x), axis=0), *accumulated_gradients)
    new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
    accumulated_gradients = []
cc @sanchit-gandhi
Hey @buttercutter! It looks like there are two different feature requests going on here! Let's focus on the JAX gradient accumulation one, since this is more relevant to the 'motivation' and code snippet you've provided. Feel free to open a separate issue for DeepSpeed activation checkpointing.
Unfortunately, gradient accumulation in JAX isn't as straightforward as using optax.apply_every! If you dig through the source code, you'll actually find that using apply_every with a batch size of N/2 and 2 accumulation steps is not necessarily equivalent to not using apply_every with a batch size of N. See https://optax.readthedocs.io/en/latest/api.html#optax.apply_every
There is an alternative in optax.MultiSteps: https://optax.readthedocs.io/en/latest/api.html#optax.MultiSteps. This will give correct gradient equivalence between using gradient accumulation and not using gradient accumulation. However, in my experiments I found it to be not super memory efficient, and consequently quite an unreliable means of using gradient accumulation. For this reason, I took the decision not to add it to the examples scripts.
Feel free to experiment with using optax.MultiSteps in your code! If you're able to get nice performance, we can explore adding it to the examples scripts! It'd be cool to benchmark the maximum permissible batch size you get without gradient accumulation, and then the maximum effective batch size you get with gradient accumulation!
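If it helps as a starting point, here is a minimal sketch of what that swap could look like. The inner optimizer, learning rate and accumulation count below are placeholder assumptions rather than values from the script:

import optax

# wrap the existing optimizer in optax.MultiSteps: gradients are accumulated over
# every_k_schedule update calls, and a real update is only applied on the k-th call
# (zero updates are emitted in between)
inner_optimizer = optax.adamw(learning_rate=1e-4)  # placeholder inner optimizer
multi_step_optimizer = optax.MultiSteps(inner_optimizer, every_k_schedule=4)

# the wrapper exposes a standard GradientTransformation to pass as tx
tx = multi_step_optimizer.gradient_transformation()
# state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=tx, dropout_rng=dropout_rng)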
In my experiments, the most memory efficient way of implementing gradient accumulation was to write a custom loop: https://github.com/sanchit-gandhi/seq2seq-speech/blob/669e51452c396b3b8605c9ac7511da8abe31038f/run_flax_speech_recognition_seq2seq.py#L1352 Now while this is the most memory efficient way, it's also the most complicated in terms of code understanding! For this reason, it's not a good fit for the Transformers examples scripts, which we try to keep as clean and lightweight as possible.
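For reference, the core idea behind that custom loop is roughly the following: split the batch into minibatches, compute per-minibatch gradients inside jax.lax.scan, sum them, and apply a single optimizer update. The sketch below is illustrative only; the loss function, input names and dropout handling are simplified assumptions, not a copy of the linked script:

import jax
import jax.numpy as jnp
import optax

def train_step(state, batch, num_accum_steps):
    # split each array from (batch_size, ...) into (num_accum_steps, batch_size // num_accum_steps, ...)
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape((num_accum_steps, x.shape[0] // num_accum_steps) + x.shape[1:]), batch
    )

    def compute_loss(params, minibatch):
        # assumes the batch carries a "labels" entry; dropout rngs and label masking are omitted for brevity
        labels = minibatch["labels"]
        model_inputs = {k: v for k, v in minibatch.items() if k != "labels"}
        logits = state.apply_fn(**model_inputs, params=params, train=True)[0]
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    grad_fn = jax.value_and_grad(compute_loss)

    def accum_step(carry, minibatch):
        loss_sum, grad_sum = carry
        loss, grads = grad_fn(state.params, minibatch)
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (loss_sum + loss, grad_sum), None

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
    (loss_sum, grad_sum), _ = jax.lax.scan(accum_step, (jnp.zeros(()), zero_grads), minibatches)

    # average so the update matches a single step over the full batch
    grads = jax.tree_util.tree_map(lambda g: g / num_accum_steps, grad_sum)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss_sum / num_accum_steps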
I am using your custom loop for train_step(), but I have the following error:
Note: in my code, training_args.per_device_gradient_accumulation_steps = 10, training_args.per_device_train_batch_size = 8, and batch has shape (8, 3600).
Traceback (most recent call last):
File "run_summarization_flax.py", line 1338, in <module>
main()
File "run_summarization_flax.py", line 1264, in main
state, train_metric = p_train_step(state, batch)
File "/home/moe/.local/lib/python3.8/site-packages/chex/_src/fake.py", line 175, in wrapped_fn
output = vmapped_fn(*call_args)
File "run_summarization_flax.py", line 1173, in train_step
batch = jax.tree_map(
File "run_summarization_flax.py", line 1174, in <lambda>
lambda x: x.reshape(
File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 793, in _reshape
return lax.reshape(a, newshape, None)
jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (8, 8, 3600) and (8, 10, 8, 3600)
@sanchit-gandhi When I run your original Python script without any modifications, it gave free(): invalid pointer.
And when I use run_librispeech.sh, it gave a similar free() error again.
sh run_librispeech.sh
src/tcmalloc.cc:332] Attempt to free invalid pointer 0x7fc48dd90558
Aborted (core dumped)
@sanchit-gandhi I am not able to use your original Python script, so I proceeded with my own Python script with the following slight modification to get past the dimension runtime error.
Note that the -1 in the reshape operation means that the size of the last dimension will be inferred from the size of x and the other dimensions. Hence the following modification will reshape batch to have shape (8, 10, 3600):
# add a first dimension over gradient_accumulation_steps for minibatch slices
batch = jax.tree_map(
    lambda x: x.reshape(
        training_args.per_device_train_batch_size, training_args.per_device_gradient_accumulation_steps, -1  # *x.shape[1::]
    ),
    batch,
)
Hey @buttercutter! Sorry for the late reply here!
The shape mismatch error you are experiencing is likely due to a difference in the number of accelerator devices. I designed my script for a TPU v3-8 (8 devices), whereas it looks like you're testing on a single GPU (1 device).
With multiple devices, we shard the data across devices by prepending an extra dimension to the start of the data: (num_devices, per_device_train_batch_size, input_shape).
We don't get this extra dimension with one device: since we run everything on a single GPU, there is no need for any data sharding. This is probably the reason for the shape mismatch we are seeing here (your data is of shape (per_device_train_batch_size, input_shape)). The workaround with setting -1 in the reshape operation looks valid in this case!
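For illustration, here is roughly how the Flax shard utility introduces that leading device dimension (the batch contents and shapes below are made-up for demonstration):

import jax
import numpy as np
from flax.training.common_utils import shard

# a toy batch with leaves of shape (total_batch_size, input_shape)
batch = {"input_ids": np.zeros((8 * jax.local_device_count(), 3600), dtype=np.int32)}

# shard splits the leading axis across the available devices, giving each leaf
# shape (num_devices, per_device_train_batch_size, input_shape)
sharded = shard(batch)
print(jax.tree_util.tree_map(lambda x: x.shape, sharded))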
Glad to see the script is working now! Let me know if you encounter any further issues - more than happy to help here!
Hey @sanchit-gandhi
How should I modify lines 1208 to 1230 to enable the gradient accumulation trick?
I have turned off the training_args.gradient_checkpointing option for now because of the following runtime error. Could you advise on this as well?
All the weights of FlaxLongT5ForConditionalGeneration were initialized from the model checkpoint at google/long-t5-tglobal-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use FlaxLongT5ForConditionalGeneration for predictions without further training.
Traceback (most recent call last):
File "run_summarization_flax.py", line 1340, in <module>
main()
File "run_summarization_flax.py", line 605, in main
model.enable_gradient_checkpointing()
AttributeError: 'FlaxLongT5ForConditionalGeneration' object has no attribute 'enable_gradient_checkpointing'
It seems that AttributeError: 'FlaxLongT5ForConditionalGeneration' object has no attribute 'enable_gradient_checkpointing' is gone after a forced reinstall of the transformers library.
The only issue left is gradient accumulation.
Hey @buttercutter,
For such specific questions, it really helps to provide a reproducible code snippet, such that the maintainer looking into the issue can replicate the error being faced and dig into the code on their end locally.
In this case, I created one that uses a 'tiny random' version of the BART model so that the forward/backward passes are fast, and a 'mini' version of the XSUM dataset such that the dataset download and preparation time is small:
python run_summarization_flax.py \
--output_dir="./" \
--model_name_or_path="sshleifer/bart-tiny-random" \
--tokenizer_name="sshleifer/bart-tiny-random" \
--dataset_name="iohadrubin/mini_xsum" \
--do_train \
--do_eval \
--predict_with_generate \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--overwrite_output_dir \
--max_source_length="64" \
--max_target_length 32
I would highly recommend this approach of using tiny/mini versions of the model/dataset when debugging to give a fast feedback loop! Having tiny/mini versions is also good practice when sharing your code, as it allows others to try the code out locally without enormous download and wait times.
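As a rough sketch of that debugging approach (the prompt text and generation settings below are arbitrary; add from_pt=True to from_pretrained if the checkpoint only ships PyTorch weights):

from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

# load the tiny random checkpoint from the command above for a quick local smoke test
tokenizer = AutoTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("sshleifer/bart-tiny-random")

inputs = tokenizer("A short article to summarize.", return_tensors="np")
summary_ids = model.generate(inputs["input_ids"], max_length=32).sequences
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))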
The easiest thing to do would be to remove all the layer/grad norm logs if you don't need them (L1208-1225). Otherwise, you can follow this fix.
Upon inspection, the keys for the layer_grad_norm and layer_param_norm need to be changed for the BART model to include an extra key. The layer grad norm values then need to be made into a jnp.array:
logs = {
"layer_grad_norm": layer_grad_norm,
- "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
+ "encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["encoder"]))),
- "decoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["decoder"])),
+ "decoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["decoder"]))),
}
Here's the full corrected code snippet:
# compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
logs = {
"layer_grad_norm": layer_grad_norm,
"encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["encoder"]))),
"decoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["decoder"]))),
}
logs["grad_norm"] = jnp.linalg.norm([logs["encoder_grad_norm"], logs["decoder_grad_norm"]])
# compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
logs["layer_param_norm"] = layer_param_norm
logs["encoder_param_norm"] = jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_param_norm["model"]["encoder"])))
logs["decoder_param_norm"] = jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_param_norm["model"]["decoder"])))
logs["param_norm"] = jnp.linalg.norm([logs["encoder_param_norm"], logs["decoder_param_norm"]])
Hope that helps!
@sanchit-gandhi
The model key seems not to be found?
Let me also do some debugging at the same time.
Traceback (most recent call last):
File "run_summarization_flax.py", line 1341, in <module>
main()
File "run_summarization_flax.py", line 1270, in main
state, train_metric = p_train_step(state, batch)
File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/api.py", line 2253, in cache_miss
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 974, in xla_pmap_impl_lazy
compiled_fun, fingerprint = parallel_callable(
File "/home/moe/.local/lib/python3.8/site-packages/jax/linear_util.py", line 303, in memoized_fun
ans = call(fun, *args)
File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1245, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1414, in lower_parallel_callable
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1321, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2065, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/moe/.local/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "run_summarization_flax.py", line 1214, in train_step
"encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["encoder"]))),
jax._src.traceback_util.UnfilteredStackTrace: KeyError: 'model'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "run_summarization_flax.py", line 1341, in <module>
main()
File "run_summarization_flax.py", line 1270, in main
state, train_metric = p_train_step(state, batch)
File "run_summarization_flax.py", line 1214, in train_step
"encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["encoder"]))),
KeyError: 'model'
@sanchit-gandhi I did a print on layer_grad_norm, and it seems that model is not one of the keys.
Could you advise?
layer_grad_norm = {'decoder': {'block': {'0': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'relative_attention_bias': {'embedding':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '1': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '10': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '11': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '2': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '3': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '4': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '5': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '6': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '7': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '8': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '9': {'layer': {'0': {'SelfAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}}, 'final_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'encoder': {'block': {'0': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'global_relative_attention_bias': {'embedding':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'relative_attention_bias': {'embedding':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '1': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '10': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '11': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '2': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '3': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '4': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '5': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '6': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '7': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '8': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '9': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}}, 'final_layer_norm': {'weight':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'lm_head': {'kernel':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'shared': {'embedding':
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}
Hey @buttercutter,
Unless you're really keen on grad/param norms and have your logger set up for this, the cleanest thing to do would be to strip the grad/param norm code out of the train step. Otherwise it adds unnecessary computations for results that you won't be analysing!
I can't reproduce your code snippet, but it looks like the model you're using has one less model key in its params than the dummy one from my code snippet. If you're set on keeping the logging code in, we need to update the dict references accordingly:
# compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
logs = {
"layer_grad_norm": layer_grad_norm,
"encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["encoder"]))),
"decoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["decoder"]))),
}
logs["grad_norm"] = jnp.linalg.norm(jnp.array([logs["encoder_grad_norm"], logs["decoder_grad_norm"]]))
# compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
logs["layer_param_norm"] = layer_param_norm
logs["encoder_param_norm"] = jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_param_norm["encoder"])))
logs["decoder_param_norm"] = jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_param_norm["decoder"])))
logs["param_norm"] = jnp.linalg.norm(jnp.array([logs["encoder_param_norm"], logs["decoder_param_norm"]]))
@sanchit-gandhi
I just confirmed that the suggested code changes to properly include logs["grad_norm"] and logs["param_norm"] actually caused an OOM error on TPU.
Epoch ... (1/16): 0%| | 0/16 [07:05<?, ?it/s]
Traceback (most recent call last):
File "run_summarization_flax.py", line 1339, in <module>
main()
File "run_summarization_flax.py", line 1268, in main
state, train_metric = p_train_step(state, batch)
ValueError: RESOURCE_EXHAUSTED: Attempting to allocate 382.18M. That was not possible. There are 375.16M free.; (0x0x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
That's probably because training is working now and we're managing to run the script past the previous error, no? As mentioned, feel free to remove all the logger code if you're not interested in tracking param/grad norms (this will save you a bit of memory).
Then you can try reducing your per_device_train_batch_size by factors of 2 and increasing gradient_accumulation_steps to compensate (i.e. try halving per_device_train_batch_size and doubling gradient_accumulation_steps until you can run the script without OOMs). We're now into the classic phase of finding a suitable training batch size for our model and accelerator device.
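To make the trade-off concrete, here is a tiny illustration (the numbers are just examples based on the values mentioned earlier in this thread):

# halving the per-device batch size while doubling the accumulation steps keeps
# the effective batch size constant but roughly halves the activation memory per step
per_device_train_batch_size = 4       # halved from 8
gradient_accumulation_steps = 20      # doubled from 10
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 80, same as 8 * 10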
@sanchit-gandhi
I reduced to even the smallest possible values, per_device_gradient_accumulation_steps=2 with per_device_train_batch_size=1, but it still gives the memory resource exhaustion OOM error.
Note: removing all the logger code you provided earlier cleared this OOM error though.
Hey @buttercutter! Awesome, if gradient accumulation is working without the logging code it sounds like we're in a good position 🚀 I'll close this issue unless there's anything else regarding grad accumulation you wanted to ask!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.