`reduce_on_plateau` causes JIT cache miss due to weak-typed `avg_value`
I'm hitting a JIT cache miss on the second training step when using `reduce_on_plateau` in an optimizer chain (AdamW -> `clip_by_global_norm` -> `reduce_on_plateau`).
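For reference, this is roughly how the optimizer is wired up (the hyperparameters and toy loss below are placeholders, not my actual model; the `value=loss` forwarding follows the pattern from the `reduce_on_plateau` docs):

```python
import jax
import jax.numpy as jnp
import optax

# Chain as described above: AdamW -> clip_by_global_norm -> reduce_on_plateau.
opt = optax.chain(
    optax.adamw(learning_rate=1e-3),
    optax.clip_by_global_norm(1.0),
    optax.contrib.reduce_on_plateau(patience=5, factor=0.5),
)

params = {"W": jnp.zeros((256, 256), jnp.float32)}
opt_state = opt.init(params)

def loss_fn(params, batch):
    # Placeholder loss just to make the sketch runnable.
    return jnp.mean((batch @ params["W"]) ** 2)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # reduce_on_plateau consumes the loss through the extra `value` argument.
    updates, opt_state = opt.update(grads, opt_state, params, value=loss)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

batch = jnp.ones((32, 256), jnp.float32)
params, opt_state, loss = train_step(params, opt_state, batch)
# With the affected optax version, this second call re-traces train_step,
# because opt_state's avg_value leaf went from weak- to strong-typed.
params, opt_state, loss = train_step(params, opt_state, batch)
```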
From JAX logs:
```
18:02:01 - WARNING - TRACING CACHE MISS at /tmp/ipython-input-3883789314.py:1204:67 (train) costing 416.977 ms because:
for train_step defined at /tmp/ipython-input-3883789314.py:1102
all previously seen cache keys are different. Closest previous key:
* key with different input types:
types now: params['nn']['W'][0]: f32[1,256], params['nn']['W'][1]: f32[256,256], par...
* at opt_state[2].avg_value, now f32[]{weak_type=False} and before f32[]{weak_type=True} <-- here
where weak_type=True often means a Python builtin numeric value, and
weak_type=False means a jax.Array.
```
As you can see, the `avg_value` field changes from weak-typed to strong-typed between the first and second optimizer updates, which invalidates the tracing cache and triggers recompilation (which costs ~8 seconds in my case).
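To make the weak/strong distinction concrete, here is a minimal standalone illustration (not the optax code itself) of how the leaf type flips once a strong-typed loss enters the running average:

```python
import jax.numpy as jnp

init_avg = jnp.asarray(0.0)                  # no dtype given -> float32 with weak_type=True
loss = jnp.asarray(1.23, jnp.float32)        # a real loss value is strong-typed (weak_type=False)
new_avg = (init_avg + loss) / 2              # promotion with a strong type drops the weak type

print(init_avg.weak_type, new_avg.weak_type) # True False -> two different avals for the same state leaf
```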
I found that in `optax/contrib/_reduce_on_plateau.py`, `avg_value` gets initialized without an explicit dtype in both `init_fn` (line 110) and `_update_scale` (line 159): `avg_value=jnp.asarray(0.0)`
After some digging, I saw that previous versions used an explicit dtype, but commit 25f870b (Sep 23, 2024) removed the hard-coded `jnp.float32` so the state would match the parameter dtypes. The `scale` field now uses `params_dtype`, but `avg_value` was left without a dtype.
What was the reasoning for leaving `avg_value` without a dtype? Could it also use `params_dtype` to prevent the weak->strong type transition? A rough sketch of what I mean is below.
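For illustration only (this is my assumption about how the dtype could be derived from the params, not the actual patch):

```python
import jax
import jax.numpy as jnp

def params_dtype(params):
    # Assumption for illustration: take the dtype of the first parameter leaf,
    # mirroring the idea that `scale` is already kept in the params' dtype.
    return jax.tree_util.tree_leaves(params)[0].dtype

params = {"W": jnp.zeros((256, 256), jnp.float32)}
avg_value = jnp.asarray(0.0, dtype=params_dtype(params))  # explicit dtype -> weak_type=False

print(avg_value.weak_type)  # False from step 0, so the leaf type never changes and the cache holds
```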
This is definitely bad! Thanks for catching it.
I'm trying to address this with: https://github.com/google-deepmind/optax/pull/1441