kfac-jax
Unexpected change in parameter registration from 'Orphan' to 'Generic' causes training failure in kfac_jax 0.0.7
When upgrading from kfac_jax==0.0.6 (with jax==0.4.35) to kfac_jax==0.0.7 (with jax==0.4.36), some parameters that were previously registered as Orphan are now being registered as Generic. This change leads to a training crash. In version 0.0.6, training proceeded normally even when parameters were marked as Orphan.
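To narrow down which parameters are affected, the registration labels seen under each version can be compared side by side. The following is only a minimal sketch: it assumes the labels have been copied by hand into two plain dictionaries (parameter name to label), and the helper `changed_registrations` and the parameter names are hypothetical placeholders, not part of kfac_jax.

```python
def changed_registrations(old, new, from_label="Orphan", to_label="Generic"):
    """Return the sorted names whose label went from `from_label` to `to_label`."""
    return sorted(
        name
        for name, label in old.items()
        if label == from_label and new.get(name) == to_label
    )


# Hypothetical placeholder names and labels, not taken from a real model.
labels_006 = {"layer_a/w": "Orphan", "layer_b/w": "Orphan", "layer_c/w": "Generic"}
labels_007 = {"layer_a/w": "Generic", "layer_b/w": "Orphan", "layer_c/w": "Generic"}

print(changed_registrations(labels_006, labels_007))  # ['layer_a/w']
```

Only parameters present in both dictionaries with exactly the Orphan to Generic transition are reported; anything that kept its label is ignored.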
Expected behavior
Either:
- The behavior in 0.0.6 should be preserved (i.e., training with Orphan parameters should continue to work), or
- If the new behavior in 0.0.7 is intentional, there should be clear documentation on:
  - The distinction between 'Orphan' and 'Generic'
  - How these affect the optimizer's behavior
Additional context / Questions
- What is the intended difference between 'Orphan' and 'Generic' parameter types in kfac_jax? The documentation and source code comments do not clearly explain this.
- What internal logic determines whether a parameter is assigned 'Orphan' vs. 'Generic'? Was this logic modified between 0.0.6 and 0.0.7?
Environment
- kfac_jax: 0.0.6 → 0.0.7
- jax: 0.4.35 → 0.4.36
- Hardware: A100 GPU
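For anyone trying to reproduce this, the installed versions can be recorded in each environment with a small standard-library script. This is a sketch under assumptions: the helper `installed_versions` is hypothetical, and the distribution names passed to it (`kfac_jax`, `jax`, `jaxlib`) are assumed to match what is installed; a package that cannot be found is reported as missing rather than raising.

```python
from importlib import metadata


def installed_versions(packages=("kfac_jax", "jax", "jaxlib")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running it once in the 0.0.6 environment and once in the 0.0.7 environment gives the two version sets to attach to a reproduction.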