
Gradient accumulation trick and Activation Checkpointing feature

buttercutter opened this issue 2 years ago • 18 comments

Feature request

  1. Adds gradient accumulation trick to https://github.com/huggingface/transformers/blob/main/examples/flax/summarization/run_summarization_flax.py
  2. Adds Activation Checkpointing feature

Motivation

To address GPU memory issues as well as to speed up the training process.

In the 'Your contribution' section below, may I ask whether the extra if-else block makes sense, or whether we even need optax.apply_every() for gradient accumulation?

Your contribution

The following JAX code is modified from the original Hugging Face version:

    batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps

    # add gradient accumulation
    if training_args.gradient_accumulation_steps > 1:
        optimizer = optax.chain(
            optax.apply_every(batch_size_per_update), optimizer
        )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)

    # ... and then inside train_step, accumulate the per-minibatch gradients before applying them:
        if len(accumulated_gradients) < training_args.gradient_accumulation_steps:
            accumulated_gradients.append(grad)
            new_state = state
        else:
            grad = jax.tree_multimap(lambda *x: jnp.sum(jnp.stack(x), axis=0), *accumulated_gradients)
            new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
            accumulated_gradients = []

buttercutter commented on Dec 21 '22

cc @sanchit-gandhi

sgugger commented on Dec 21 '22

Hey @buttercutter! It looks like there are two different feature requests going on here! Let's focus on the JAX gradient accumulation one, since this is more relevant to the 'motivation' and code snippet you've provided. Feel free to open a separate issue for DeepSpeed activation checkpointing.

Unfortunately, gradient accumulation in JAX isn't as straightforward as using optax.apply_every! If you dig through the source code, you'll actually find that using apply_every with a batch size of N/2 and 2 accumulation steps is not necessarily equivalent to not using apply_every with a batch size of N. See https://optax.readthedocs.io/en/latest/api.html#optax.apply_every

There is an alternative in optax.MultiSteps: https://optax.readthedocs.io/en/latest/api.html#optax.MultiSteps. This will give correct gradient equivalence between using gradient accumulation and not using gradient accumulation. However, in my experiments I found it to be not especially memory efficient, and consequently quite an unreliable means of using gradient accumulation. For this reason, I took the decision not to add it to the examples scripts.
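As a rough sketch of what that would look like in your snippet above (assuming the same optimizer, training_args, model and TrainState objects; gradient_accumulation_steps is a custom argument here, not a flag the Flax example script currently exposes):

import optax

if training_args.gradient_accumulation_steps > 1:
    # accumulate gradients over k update calls, then apply the wrapped optimizer once,
    # giving gradient equivalence with a k-times larger batch
    optimizer = optax.MultiSteps(
        optimizer, every_k_schedule=training_args.gradient_accumulation_steps
    ).gradient_transformation()

state = TrainState.create(
    apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng
)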

Feel free to experiment with using optax.MultiSteps in your code! If you're able to get nice performance, we can explore adding it to the examples scripts! It'd be cool to benchmark the maximum permissible batch size you get without gradient accumulation, and then the maximum effective batch size you get with gradient accumulation!

In my experiments, the most memory efficient way of implementing gradient accumulation was to write a custom loop: https://github.com/sanchit-gandhi/seq2seq-speech/blob/669e51452c396b3b8605c9ac7511da8abe31038f/run_flax_speech_recognition_seq2seq.py#L1352 While this is the most memory efficient way, it's also the most complicated in terms of code understanding! For this reason, it's also not a good fit for the Transformers examples scripts, which we try and keep as clean and lightweight as possible.
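To sketch the idea (a simplification of the linked script, not its exact code): split the per-device batch into accumulation-step minibatches, sum the gradients over them, then take a single optimiser step. The sketch below assumes a loss_fn(params, minibatch) and a Flax TrainState as in the examples scripts, and that the batch divides evenly into minibatches:

import jax
import jax.numpy as jnp

def train_step(state, batch, accum_steps):
    grad_fn = jax.value_and_grad(loss_fn)  # loss_fn(params, minibatch) -> scalar loss (assumed)

    # reshape (batch, ...) -> (accum_steps, batch // accum_steps, ...)
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape((accum_steps, -1) + x.shape[1:]), batch
    )

    grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
    loss = 0.0
    for step in range(accum_steps):
        # the Python loop unrolls over the accumulation steps at trace time
        minibatch = jax.tree_util.tree_map(lambda x: x[step], minibatches)
        step_loss, step_grads = grad_fn(state.params, minibatch)
        grads = jax.tree_util.tree_map(jnp.add, grads, step_grads)
        loss = loss + step_loss

    # average over the accumulation steps, then apply a single optimiser update
    grads = jax.tree_util.tree_map(lambda g: g / accum_steps, grads)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss / accum_steps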

sanchit-gandhi commented on Dec 21 '22

I am using your custom loop for train_step(), but I get the following error.

Note: In my code, training_args.per_device_gradient_accumulation_steps = 10, training_args.per_device_train_batch_size = 8, and batch has shape (8, 3600).

Traceback (most recent call last):
  File "run_summarization_flax.py", line 1338, in <module>
    main()
  File "run_summarization_flax.py", line 1264, in main
    state, train_metric = p_train_step(state, batch)
  File "/home/moe/.local/lib/python3.8/site-packages/chex/_src/fake.py", line 175, in wrapped_fn
    output = vmapped_fn(*call_args)
  File "run_summarization_flax.py", line 1173, in train_step
    batch = jax.tree_map(
  File "run_summarization_flax.py", line 1174, in <lambda>
    lambda x: x.reshape(
  File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 793, in _reshape
    return lax.reshape(a, newshape, None)
jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (8, 8, 3600) and (8, 10, 8, 3600)

buttercutter commented on Dec 23 '22

@sanchit-gandhi When I run your original Python script without any modifications, it gives a free(): invalid pointer error?

And when I use run_librispeech.sh, it gives a similar free() error again:

sh run_librispeech.sh 
src/tcmalloc.cc:332] Attempt to free invalid pointer 0x7fc48dd90558 
Aborted (core dumped)

buttercutter commented on Dec 26 '22

@sanchit-gandhi I am not able to use your original Python script, so I proceeded with my own script with the following slight modification to get past the dimension runtime error.

Note that the -1 in the reshape operation means that the size of the last dimension will be inferred from the size of x and the other dimensions. Hence the following modification reshapes batch to have shape (8, 10, 3600):

# add a first dimension over gradient_accumulation_steps for minibatch slices
batch = jax.tree_map(
    lambda x: x.reshape(
        training_args.per_device_train_batch_size, training_args.per_device_gradient_accumulation_steps, -1 #*x.shape[1::]
    ),
    batch,
)

buttercutter commented on Dec 26 '22

Hey @buttercutter! Sorry for the late reply here!

The shape mismatch error you are experiencing is likely due to a difference in the number of accelerator devices. I wrote my script for a TPU v3-8 (8 devices), whereas it looks like you're testing on a single GPU (1 device).

With multiple devices, we shard the data across devices by prepending an extra dimension to the start of the data: (num_devices, per_device_train_batch_size, input_shape).

We don't get this extra dimension with one device: since we run everything on a single GPU, there is no need for any data sharding. This is probably the reason for the shape mismatch we are seeing here (your data is of shape (per_device_train_batch_size, input_shape)). The workaround of setting -1 in the reshape operation looks valid in this case!
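To make the multi-device case concrete, here is a rough sketch of the reshape the examples scripts apply before pmap (the real scripts use a similar shard helper from flax.training.common_utils; the shapes below are purely illustrative):

import jax
import numpy as np

def shard(batch):
    # prepend a device axis: (batch_size, ...) -> (num_devices, batch_size // num_devices, ...)
    num_devices = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda x: x.reshape((num_devices, -1) + x.shape[1:]), batch
    )

batch = {"input_ids": np.zeros((64, 3600), dtype=np.int32)}
sharded = shard(batch)  # on a TPU v3-8 (8 devices) this gives shape (8, 8, 3600)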

Glad to see the script is working now! Let me know if you encounter any further issues - more than happy to help here!

sanchit-gandhi commented on Jan 05 '23

I have turned off the training_args.gradient_checkpointing option for now because of the following runtime error. Could you also advise on this?

All the weights of FlaxLongT5ForConditionalGeneration were initialized from the model checkpoint at google/long-t5-tglobal-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use FlaxLongT5ForConditionalGeneration for predictions without further training.
Traceback (most recent call last):
  File "run_summarization_flax.py", line 1340, in <module>
    main()
  File "run_summarization_flax.py", line 605, in main
    model.enable_gradient_checkpointing()
AttributeError: 'FlaxLongT5ForConditionalGeneration' object has no attribute 'enable_gradient_checkpointing'

buttercutter commented on Jan 09 '23

It seems that the AttributeError: 'FlaxLongT5ForConditionalGeneration' object has no attribute 'enable_gradient_checkpointing' is gone after a forced reinstall of the transformers library.

The only issue left is gradient accumulation.
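For reference, the relevant part of my script now looks roughly like the sketch below (not the exact code; training_args.gradient_checkpointing is the custom option mentioned above):

import jax.numpy as jnp
from transformers import FlaxLongT5ForConditionalGeneration

# load the model in bfloat16
model = FlaxLongT5ForConditionalGeneration.from_pretrained(
    "google/long-t5-tglobal-base", dtype=jnp.bfloat16
)

if training_args.gradient_checkpointing:
    # rematerialise activations during the backward pass to trade compute for memory
    model.enable_gradient_checkpointing()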

buttercutter commented on Jan 11 '23

@sanchit-gandhi these code changes at least bypass the gradient accumulation runtime error for now.

(screenshot of the code changes, attached as an image in the original issue)

buttercutter commented on Jan 11 '23

Hey @buttercutter,

For such specific questions, it really helps to provide a reproducible code snippet, so that the maintainer looking into the issue can replicate the error you're facing and dig into the code locally on their end.

In this case, I created one that uses a 'tiny random' version of the BART model so that the forward/backward passes are fast, and a 'mini' version of the XSUM dataset such that the dataset download and preparation time is small:

python run_summarization_flax.py \
    --output_dir="./" \
    --model_name_or_path="sshleifer/bart-tiny-random" \
    --tokenizer_name="sshleifer/bart-tiny-random" \
    --dataset_name="iohadrubin/mini_xsum" \
    --do_train \
    --do_eval \
    --predict_with_generate \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --overwrite_output_dir \
    --max_source_length="64" \
    --max_target_length 32

I would highly recommend this approach of using tiny/mini versions of the model/dataset when debugging to give a fast feedback loop! Having tiny/mini versions is also good practice when sharing your code, as it allows others to try the code out locally without enormous download and wait times.

The easiest thing to do would be to remove all the layer/grad norm logs if you don't need them (L1208-1225). Otherwise, you can follow this fix.

Upon inspection, the keys for the layer_grad_norm and layer_param_norm need to be changed for the BART model to include an extra key. The layer grad norm values then need to be made into a jnp.array:

        logs = {
            "layer_grad_norm": layer_grad_norm,
-           "encoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["encoder"])),
+           "encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["encoder"]))),
-           "decoder_grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm["decoder"])),
+           "decoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["decoder"]))),
        }

Here's the full corrected code snippet:

        # compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
        layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
        logs = {
            "layer_grad_norm": layer_grad_norm,
            "encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["encoder"]))),
            "decoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["decoder"]))),
        }
        logs["grad_norm"] = jnp.linalg.norm([logs["encoder_grad_norm"], logs["decoder_grad_norm"]])

        # compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
        layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
        logs["layer_param_norm"] = layer_param_norm
        logs["encoder_param_norm"] = jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_param_norm["model"]["encoder"])))
        logs["decoder_param_norm"] = jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_param_norm["model"]["decoder"])))
        logs["param_norm"] = jnp.linalg.norm([logs["encoder_param_norm"], logs["decoder_param_norm"]])

Hope that helps!

sanchit-gandhi commented on Jan 13 '23

@sanchit-gandhi

The model key doesn't seem to be found?

Let me also do some debugging at the same time.

Traceback (most recent call last):
  File "run_summarization_flax.py", line 1341, in <module>
    main()
  File "run_summarization_flax.py", line 1270, in main
    state, train_metric = p_train_step(state, batch)
  File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/api.py", line 2253, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 974, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/home/moe/.local/lib/python3.8/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1245, in parallel_callable
    pmap_computation = lower_parallel_callable(
  File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1414, in lower_parallel_callable
    jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
  File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1321, in stage_parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/moe/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2065, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/moe/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/moe/.local/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "run_summarization_flax.py", line 1214, in train_step
    "encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["encoder"]))),
jax._src.traceback_util.UnfilteredStackTrace: KeyError: 'model'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "run_summarization_flax.py", line 1341, in <module>
    main()
  File "run_summarization_flax.py", line 1270, in main
    state, train_metric = p_train_step(state, batch)
  File "run_summarization_flax.py", line 1214, in train_step
    "encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["model"]["encoder"]))),
KeyError: 'model'

buttercutter commented on Jan 17 '23

@sanchit-gandhi I did a print on layer_grad_norm, and it seems that model is not one of the keys.

Could you advise?

layer_grad_norm =  {'decoder': {'block': {'0': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'relative_attention_bias': {'embedding': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '1': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '10': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '11': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '2': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '3': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '4': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '5': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '6': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '7': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '8': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '9': {'layer': {'0': {'SelfAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'EncDecAttention': {'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '2': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}}, 'final_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'encoder': {'block': {'0': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'global_relative_attention_bias': {'embedding': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'relative_attention_bias': {'embedding': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '1': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '10': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '11': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '2': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '3': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '4': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '5': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '6': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '7': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '8': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}, '9': {'layer': {'0': {'TransientGlobalSelfAttention': {'global_input_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'k': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'o': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'q': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'v': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, '1': {'DenseReluDense': {'wi_0': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wi_1': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'wo': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}}}}, 'final_layer_norm': {'weight': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}, 'lm_head': {'kernel': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}, 'shared': {'embedding': 
Traced<ShapedArray(bfloat16[])>with<DynamicJaxprTrace(level=0/1)>}}

buttercutter commented on Jan 17 '23

Hey @buttercutter,

Unless you're really keen on grad/param norms and have your logger set up for this, the cleanest thing to do would be to strip the grad/param norm code out of the train step. Otherwise it adds unnecessary computation for results that you won't be analysing!

I can't reproduce your code snippet, but it looks like the model you're using is missing the top-level model key in its params compared to the dummy one from my code snippet. If you're set on keeping the logging code in, we need to update the dict references accordingly:

        # compute gradient norms over all layers, total encoder, total decoder and global for detailed monitoring
        layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
        logs = {
            "layer_grad_norm": layer_grad_norm,
            "encoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["encoder"]))),
            "decoder_grad_norm": jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_grad_norm["decoder"]))),
        }
        logs["grad_norm"] = jnp.linalg.norm(jnp.array([logs["encoder_grad_norm"], logs["decoder_grad_norm"]]))

        # compute parameter norms over all layers, total encoder, total decoder and global for detailed monitoring
        layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
        logs["layer_param_norm"] = layer_param_norm
        logs["encoder_param_norm"] = jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_param_norm["encoder"])))
        logs["decoder_param_norm"] = jnp.linalg.norm(jnp.array(jax.tree_util.tree_leaves(layer_param_norm["decoder"])))
        logs["param_norm"] = jnp.linalg.norm(jnp.array([logs["encoder_param_norm"], logs["decoder_param_norm"]]))

sanchit-gandhi commented on Jan 19 '23

@sanchit-gandhi

I just confirmed that the suggested code changes to properly include logs["grad_norm"] and logs["param_norm"] actually cause an OOM error on TPU.

Epoch ... (1/16):   0%|          | 0/16 [07:05<?, ?it/s]
Traceback (most recent call last):
  File "run_summarization_flax.py", line 1339, in <module>
    main()
  File "run_summarization_flax.py", line 1268, in main
    state, train_metric = p_train_step(state, batch)
ValueError: RESOURCE_EXHAUSTED: Attempting to allocate 382.18M. That was not possible. There are 375.16M free.; (0x0x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

buttercutter commented on Jan 19 '23

That's probably because training is working now and we're managing to run the script past the previous error, no? As mentioned, feel free to remove all the logger code if you're not interested in tracking param/grad norms (this will save you a bit of memory).

Then you can try reducing your per_device_train_batch_size by factors of 2 and increasing gradient_accumulation_steps to compensate (i.e. try halving per_device_train_batch_size and doubling gradient_accumulation_steps until you can run the script without OOMs). We're now into the classic phase of finding a suitable training batch size for our model and accelerator device.
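As a quick sanity check while you search, the quantity you're keeping constant is roughly the following (a sketch using your custom per_device_gradient_accumulation_steps argument):

import jax

# samples contributing to each optimiser update; halving per_device_train_batch_size and
# doubling the accumulation steps keeps this constant while lowering peak memory
effective_batch_size = (
    training_args.per_device_train_batch_size
    * training_args.per_device_gradient_accumulation_steps
    * jax.local_device_count()
)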

sanchit-gandhi commented on Jan 19 '23

@sanchit-gandhi

I had reduced things to the smallest possible values, per_device_gradient_accumulation_steps=2 with per_device_train_batch_size=1, but it still gives the resource-exhaustion OOM error.

Note: Removing all the logger code you provided earlier cleared this OOM error though.

buttercutter commented on Jan 20 '23

Hey @buttercutter! Awesome, if gradient accumulation is working without the logging code it sounds like we're in a good position 🚀 I'll close this issue unless there's anything else regarding grad accumulation you wanted to ask!

sanchit-gandhi commented on Jan 24 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] commented on Feb 17 '23