
`reduce_on_plateau` causes JIT cache miss due to weak-typed `avg_value`

Mycroft-47 opened this issue 2 months ago · 1 comment

I'm hitting a JIT cache miss on the second training step when using reduce_on_plateau in an optimizer chain (AdamW -> clip_by_global_norm -> reduce_on_plateau).
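For reference, the setup looks roughly like this (a minimal sketch with made-up hyperparameters; I'm assuming `optax.contrib.reduce_on_plateau` with default arguments and the loss passed through the `value` extra argument):

```python
import jax
import jax.numpy as jnp
import optax
from optax import contrib

# Sketch of the optimizer chain from the report; hyperparameters are made up.
opt = optax.chain(
    optax.adamw(learning_rate=1e-3),
    optax.clip_by_global_norm(1.0),
    contrib.reduce_on_plateau(),
)

params = {"w": jnp.zeros((4,))}
opt_state = opt.init(params)

def loss_fn(p):
    return jnp.sum(p["w"] ** 2)

loss, grads = jax.value_and_grad(loss_fn)(params)
# reduce_on_plateau reads the current loss through the `value` extra argument.
updates, opt_state = opt.update(grads, opt_state, params, value=loss)
params = optax.apply_updates(params, updates)
```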

From JAX logs:

```
18:02:01 - WARNING - TRACING CACHE MISS at /tmp/ipython-input-3883789314.py:1204:67 (train) costing 416.977 ms because:
  for train_step defined at /tmp/ipython-input-3883789314.py:1102
  all previously seen cache keys are different. Closest previous key:
  * key with different input types:
      types now: params['nn']['W'][0]: f32[1,256], params['nn']['W'][1]: f32[256,256], par...
        * at opt_state[2].avg_value, now f32[]{weak_type=False} and before f32[]{weak_type=True}           <-- here
    where weak_type=True often means a Python builtin numeric value, and
    weak_type=False means a jax.Array.
```

As you can see, the `avg_value` field changes from weak-typed to strong-typed between the first and second optimizer updates, which invalidates the cache and triggers recompilation (costing ~8 seconds in my case).
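The effect can be reproduced outside optax. Here is a minimal sketch (the function and values are made up) showing that a weak-to-strong change on an otherwise identical `f32[]` input forces `jax.jit` to retrace:

```python
import jax
import jax.numpy as jnp

weak = jnp.asarray(0.0)                       # from a Python float -> weak_type=True
strong = jnp.asarray(0.0, dtype=jnp.float32)  # explicit dtype      -> weak_type=False
print(weak.weak_type, strong.weak_type)       # True False

@jax.jit
def step(avg_value):
    print("tracing")  # trace-time side effect: runs once per distinct input signature
    return avg_value * 0.9

step(weak)    # traces and compiles
step(strong)  # same shape and dtype, but different weak_type -> traces and compiles again
```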

I found that in `optax/contrib/_reduce_on_plateau.py`, `avg_value` gets initialized without an explicit dtype in both `init_fn` (line 110) and `_update_scale` (line 159): `avg_value=jnp.asarray(0.0)`.

After some digging, I saw that previous versions had an explicit dtype, but commit 25f870b (Sep 23, 2024) removed the `jnp.float32` so that the state matches the parameter dtypes. The `scale` field now uses `params_dtype`, but `avg_value` was left untyped.

What was the reasoning for leaving `avg_value` without a dtype? Could it also use `params_dtype` to prevent the weak-to-strong type transition?
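For concreteness, the change I have in mind would look something like this (a hypothetical sketch, not the actual optax source; `_init_avg_value` is a made-up helper):

```python
import jax
import jax.numpy as jnp

def _init_avg_value(params):
    # Hypothetical helper: derive the dtype from the parameters, mirroring what
    # the `scale` field already does, so avg_value is never a weak-typed scalar.
    leaves = jax.tree_util.tree_leaves(params)
    params_dtype = leaves[0].dtype if leaves else jnp.float32
    return jnp.asarray(0.0, dtype=params_dtype)
```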

Mycroft-47 · Oct 24 '25, 19:10

This is definitely bad! Thanks for catching it.

I'm trying to address this with: https://github.com/google-deepmind/optax/pull/1441

rdyro · Oct 24 '25, 22:10