kfac-jax
TypeError: unhashable type: 'Literal'
Issue Report: TypeError with Constant Multiplication in the Network When Using kfac_jax
Summary
When using kfac_jax with a network, introducing a scaling factor (e.g., geo_scale) for the lattice parameters into the computational graph raises TypeError: unhashable type: 'Literal'. This happens whether the factor is written as a constant multiplier or passed in as a geo_scale argument of the function.
Steps to Reproduce
Here are examples illustrating the issue (get_jacobian is part of the network):
Working Example: No Scaling
The following works without errors:

```python
import jax.numpy as jnp

def get_jacobian(params):
    p_cell = params['cell'].ravel()  # No scaling applied
    return jnp.diag(p_cell)
```
Failing Example 1: Direct Multiplication
Adding a constant multiplier to params['cell'] causes a TypeError:

```python
import jax.numpy as jnp

def get_jacobian(params):
    p_cell = params['cell'].ravel() * 1e-3  # Multiplying with a constant
    return jnp.diag(p_cell)
```

Error raised:

```
TypeError: unhashable type: 'Literal'
```
Failing Example 2: Adding a geo_scale Parameter
Introducing a geo_scale parameter also causes the same TypeError:

```python
import jax.numpy as jnp

def get_jacobian(params, geo_scale=1e-3):
    p_cell = params['cell'].ravel() * geo_scale
    return jnp.diag(p_cell)
```

Error raised:

```
TypeError: unhashable type: 'Literal'
```
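To see where the Literal comes from, the traced graph can be inspected directly with jax.make_jaxpr. The sketch below lists every Literal operand in the jaxpr of the two failing variants; the helper names iter_eqns and literal_operands, and the (3, 1) shape of params['cell'], are hypothetical and only for illustration. It does not reproduce the kfac_jax error itself; it only shows that the scale factor is inlined into the graph as a Literal. In the geo_scale variant the default value is still a plain Python float, so after tracing it presumably looks the same as the hard-coded 1e-3, which would explain why both variants fail the same way.

```python
import jax
import jax.numpy as jnp


def iter_eqns(jaxpr):
    """Yield every equation of a jaxpr, descending into nested (e.g. jit) bodies."""
    for eqn in jaxpr.eqns:
        yield eqn
        for param in eqn.params.values():
            # jit-style primitives keep their body as a ClosedJaxpr; unwrap it to a Jaxpr.
            sub = getattr(param, "jaxpr", param)
            if hasattr(sub, "eqns"):
                yield from iter_eqns(sub)


def literal_operands(fn, *args):
    """Trace fn and list (primitive name, value) for every Literal operand."""
    closed = jax.make_jaxpr(fn)(*args)
    return [
        (eqn.primitive.name, float(v.val))
        for eqn in iter_eqns(closed.jaxpr)
        for v in eqn.invars
        # Match on the class name to avoid depending on where jax exposes Literal.
        if type(v).__name__ == "Literal"
    ]


def get_jacobian_const(params):
    # Failing example 1: Python-float multiplier.
    p_cell = params["cell"].ravel() * 1e-3
    return jnp.diag(p_cell)


def get_jacobian_geo_scale(params, geo_scale=1e-3):
    # Failing example 2: the default geo_scale is still a Python float.
    p_cell = params["cell"].ravel() * geo_scale
    return jnp.diag(p_cell)


params = {"cell": jnp.ones((3, 1))}  # hypothetical lattice parameters
print(literal_operands(get_jacobian_const, params))
print(literal_operands(get_jacobian_geo_scale, params))
```

In both printouts the value 0.001 (rounded to float32) should show up as a Literal operand; other Literals, such as padding constants from jnp.diag, may also appear, so their presence alone does not pinpoint the failure.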
Questions
- Is there a recommended approach for handling constants or scaling factors like geo_scale in kfac_jax workflows to avoid such issues?
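One workaround that may be worth trying, untested against kfac_jax and not a documented recommendation: make the scale a traced array input instead of a Python constant, so the multiply consumes a jaxpr variable rather than an inlined Literal. The names get_jacobian_traced_scale and literal_values below are hypothetical. In a kfac_jax training loop this would mean routing geo_scale through an argument the loss function already receives instead of closing over a float; whether kfac_jax's graph matching then accepts the graph has not been checked.

```python
import jax
import jax.numpy as jnp


def literal_values(fn, *args):
    """Trace fn and return the values of all top-level Literal operands."""
    closed = jax.make_jaxpr(fn)(*args)
    return [
        float(v.val)
        for eqn in closed.jaxpr.eqns
        for v in eqn.invars
        if type(v).__name__ == "Literal"
    ]


def get_jacobian_traced_scale(params, geo_scale):
    # geo_scale is an ordinary (traced) array input, not a Python constant,
    # so the multiply consumes a jaxpr variable instead of an inlined Literal.
    p_cell = params["cell"].ravel() * geo_scale
    return jnp.diag(p_cell)


params = {"cell": jnp.ones((3, 1))}  # hypothetical lattice parameters
geo_scale = jnp.asarray(1e-3, dtype=jnp.float32)

print(get_jacobian_traced_scale(params, geo_scale))
print(literal_values(get_jacobian_traced_scale, params, geo_scale))
```

The numerical result is the same as Failing Example 1; the only difference is that 0.001 no longer appears among the Literal operands of the traced graph.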