kfac-jax

TypeError: unhashable type: 'Literal'

Status: Open. DanChai22 opened this issue 1 year ago (0 comments).

Issue Report: TypeError with Constant Multiplication in a Network When Using kfac_jax

Summary

When using kfac_jax with a network, introducing a scaling factor (e.g., geo_scale) for the lattice parameters in the computational graph raises TypeError: unhashable type: 'Literal'. This happens whether geo_scale is passed as a function argument or hard-coded as a constant multiplier in the computation.


Steps to Reproduce

Here’s an example illustrating the issue (get_jacobian is part of the network):

  1. Working Example: No Scaling
    The following works without errors:

    import jax.numpy as jnp

    def get_jacobian(params):
        p_cell = params['cell'].ravel()  # No scaling applied
        return jnp.diag(p_cell)
    
  2. Failing Example 1: Direct Multiplication
    Adding a constant multiplier to params['cell'] causes a TypeError:

    def get_jacobian(params):
        p_cell = params['cell'].ravel() * 1e-3  # Multiplying with a constant
        return jnp.diag(p_cell)
    

    Error Raised:

    TypeError: unhashable type: 'Literal'
    
  3. Failing Example 2: Adding geo_scale Parameter
    Introducing a geo_scale keyword argument (default 1e-3) also causes the same TypeError:

    def get_jacobian(params, geo_scale=1e-3):
        p_cell = params['cell'].ravel() * geo_scale
        return jnp.diag(p_cell)
    

    Error Raised:

    TypeError: unhashable type: 'Literal'
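One way to see where the Literal comes from is to trace the three variants with jax.make_jaxpr and look for constants that JAX has inlined into the jaxpr. JAX inlines scalar Python constants as Literal operands, and a default argument such as geo_scale=1e-3 is still a plain Python float at trace time, so both failing variants should contain the factor as a Literal while the working one does not. The sketch below assumes a recent JAX release and a hypothetical 3-element params['cell']; the helper names (iter_eqns, literal_operands) are illustrative only. It does not call kfac_jax, so it shows what the traced graph looks like, not where kfac_jax tries to hash the Literal.

```python
import jax
import jax.numpy as jnp


def get_jacobian_ok(params):
    p_cell = params['cell'].ravel()  # no scaling applied
    return jnp.diag(p_cell)


def get_jacobian_const(params):
    p_cell = params['cell'].ravel() * 1e-3  # hard-coded Python constant
    return jnp.diag(p_cell)


def get_jacobian_default(params, geo_scale=1e-3):
    # geo_scale is still a Python float when traced, so it is inlined as well
    p_cell = params['cell'].ravel() * geo_scale
    return jnp.diag(p_cell)


def iter_eqns(jaxpr):
    """Yield every equation, descending into nested sub-jaxprs (e.g. jit calls)."""
    for eqn in jaxpr.eqns:
        yield eqn
        for p in eqn.params.values():
            for sub in (p if isinstance(p, (tuple, list)) else (p,)):
                if hasattr(sub, 'eqns'):  # Jaxpr (ClosedJaxpr also exposes .eqns)
                    yield from iter_eqns(sub)
                elif hasattr(getattr(sub, 'jaxpr', None), 'eqns'):
                    yield from iter_eqns(sub.jaxpr)


def literal_operands(fn, *args):
    """List (primitive name, value) for every Literal operand in fn's jaxpr."""
    closed = jax.make_jaxpr(fn)(*args)
    return [
        (eqn.primitive.name, float(v.val))
        for eqn in iter_eqns(closed.jaxpr)
        for v in eqn.invars
        # compare the class name to avoid depending on a jax.core import path
        if type(v).__name__ == 'Literal'
    ]


params = {'cell': jnp.arange(1.0, 4.0)}  # hypothetical lattice parameter
for fn in (get_jacobian_ok, get_jacobian_const, get_jacobian_default):
    print(fn.__name__, literal_operands(fn, params))
```

Literals such as 0.0 may also show up in every variant because jnp.diag uses them internally; the difference to look for is the extra value close to 0.001 in the two scaled versions.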
    

Questions

  1. Is there a recommended approach for handling constants or scaling factors like geo_scale in kfac_jax workflows to avoid such issues?
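A possible workaround, sketched here as an untested assumption rather than a documented kfac_jax recommendation, is to stop feeding geo_scale in as a Python constant and instead pass it as a JAX array through an argument that is actually traced, so the multiplication sees an input variable of the jaxpr rather than an inlined Literal. The sketch only checks this with jax.make_jaxpr; whether kfac_jax's layer tagging then accepts the graph, and how best to route geo_scale into the function kfac_jax traces (for example via the batch argument), are not verified. Putting geo_scale into params would also avoid the Literal, but it would then be treated as a trainable parameter. The helper literal_values is hypothetical.

```python
import jax
import jax.numpy as jnp


def get_jacobian(params, geo_scale):
    # Assumed workaround: geo_scale arrives as a JAX array argument, so when
    # the function is traced it is an input variable, not an inlined Literal.
    p_cell = params['cell'].ravel() * geo_scale
    return jnp.diag(p_cell)


def iter_eqns(jaxpr):
    """Yield every equation, descending into nested sub-jaxprs (e.g. jit calls)."""
    for eqn in jaxpr.eqns:
        yield eqn
        for p in eqn.params.values():
            for sub in (p if isinstance(p, (tuple, list)) else (p,)):
                if hasattr(sub, 'eqns'):  # Jaxpr (ClosedJaxpr also exposes .eqns)
                    yield from iter_eqns(sub)
                elif hasattr(getattr(sub, 'jaxpr', None), 'eqns'):
                    yield from iter_eqns(sub.jaxpr)


def literal_values(fn, *args):
    """Values of all Literal operands in the jaxpr of fn(*args)."""
    closed = jax.make_jaxpr(fn)(*args)
    return [
        float(v.val)
        for eqn in iter_eqns(closed.jaxpr)
        for v in eqn.invars
        if type(v).__name__ == 'Literal'
    ]


params = {'cell': jnp.arange(1.0, 4.0)}  # hypothetical lattice parameter
geo_scale = jnp.asarray(1e-3, dtype=jnp.float32)  # an array, not a Python float

print(get_jacobian(params, geo_scale))
print(literal_values(get_jacobian, params, geo_scale))  # 0.001 should not be listed
```

Note that a scalar array created inside the traced function (for example jnp.asarray(1e-3) written directly in get_jacobian) would still be a closed-over constant and could be inlined as a Literal again, so the scale has to come in from outside the trace.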

DanChai22, Jan 13 '25 07:01