Unexpected change in parameter registration from 'Orphan' to 'Generic' causes training failure in kfac_jax 0.0.7

Open · DanChai22 opened this issue 7 months ago · 0 comments

When upgrading from kfac_jax==0.0.6 (with jax==0.4.35) to kfac_jax==0.0.7 (with jax==0.4.36), some parameters that were previously registered as 'Orphan' are now registered as 'Generic', and training crashes as a result. With 0.0.6, training ran normally even though these parameters were marked as 'Orphan'.
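To narrow down exactly which parameters changed, one option is to record the registration label of every parameter under each version and diff the two. The sketch below assumes you have already collected, for each version, a plain mapping from parameter path to label (e.g. "Orphan" or "Generic"). How to extract those labels from kfac_jax is not shown and may differ between versions. The helper names `diff_registrations` and `orphan_to_generic` are hypothetical, and the example data is made up.

```python
from typing import Dict, List, Tuple

# Assumed input format (hypothetical): parameter path -> registration label,
# e.g. {"encoder/norm/scale": "Orphan"}. Obtaining this from kfac_jax is not shown.
Registration = Dict[str, str]

MISSING = "<missing>"


def diff_registrations(
    old: Registration, new: Registration
) -> List[Tuple[str, str, str]]:
    """Return (path, old_label, new_label) for every parameter whose label changed.

    A parameter present in only one mapping is reported with MISSING
    as its label on the other side.
    """
    changes = []
    for path in sorted(set(old) | set(new)):
        before = old.get(path, MISSING)
        after = new.get(path, MISSING)
        if before != after:
            changes.append((path, before, after))
    return changes


def orphan_to_generic(old: Registration, new: Registration) -> List[str]:
    """Paths labelled 'Orphan' in the old mapping and 'Generic' in the new one."""
    return [
        path
        for path, before, after in diff_registrations(old, new)
        if before == "Orphan" and after == "Generic"
    ]


if __name__ == "__main__":
    # Made-up labels for illustration only; not real kfac_jax output.
    v006 = {"mlp/w": "Dense", "mlp/b": "Dense", "norm/scale": "Orphan"}
    v007 = {"mlp/w": "Dense", "mlp/b": "Dense", "norm/scale": "Generic"}
    for path, before, after in diff_registrations(v006, v007):
        print(f"{path}: {before} -> {after}")
    print("Orphan -> Generic:", orphan_to_generic(v006, v007))
```

Attaching a list like this to the report would show the maintainers which layers are affected, without relying on any particular kfac_jax internals.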

Expected behavior

Either:

  • The behavior in 0.0.6 should be preserved (i.e., training with Orphan parameters should continue to work), or
  • If the new behavior in 0.0.7 is intentional, there should be clear documentation on:
    • The distinction between 'Orphan' and 'Generic'
    • How these affect the optimizer's behavior

Additional context / Questions

  1. What is the intended difference between 'Orphan' and 'Generic' parameter types in kfac_jax?
    The documentation and source code comments do not clearly explain this.

  2. What internal logic determines whether a parameter is assigned 'Orphan' vs. 'Generic'?
    Was this logic modified between 0.0.6 and 0.0.7?

Environment

  • kfac_jax: 0.0.6 → 0.0.7
  • jax: 0.4.35 → 0.4.36
  • Hardware: A100 GPU
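The version information above can be collected with a short script when filing or updating the report. This is a minimal sketch that assumes the PyPI distribution names are "kfac-jax", "jax" and "jaxlib". A package that cannot be found is reported as None instead of raising. The device listing uses `jax.devices()` only if jax is importable.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional


def package_versions(
    names: Iterable[str] = ("kfac-jax", "jax", "jaxlib"),
) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if not found."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(f"python: {platform.python_version()}")
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not installed'}")
    # Hardware: list the accelerators jax can see (e.g. an A100 GPU), if jax imports.
    try:
        import jax

        print("devices:", jax.devices())
    except ImportError:
        print("devices: jax not importable")
```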

DanChai22 · Jun 03 '25 12:06