
lax.scan for Equinox Modules?

Open · mchagneux opened this issue 10 months ago · 9 comments

Hi all,

I've created some Equinox modules and I want to process them with jax.lax.scan, but I don't see an equinox.filter_scan. Is there a particular reason for that?

Here's some example code that doesn't work out of the box:

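```python
import jax
import jax.numpy as jnp

# LinearGaussianKernel, some_kernel and d_x are defined elsewhere in my code.

def transform_kernel(carry, x):
    old_kernel = carry
    offset = x
    # Shift the bias of the linear mapping by this step's offset.
    new_mapping_bias = old_kernel.mapping._params['b'] + offset
    new_params = {'mapping': {'w': old_kernel.mapping._params['w'],
                              'b': new_mapping_bias},
                  'noise': {'loc': old_kernel.noise_dist.base_dist.loc,
                            'log_std': jnp.log(old_kernel.noise_dist.base_dist.scale)}}

    new_kernel = LinearGaussianKernel(new_params)
    # The new kernel is both the next carry and this step's output.
    return new_kernel, new_kernel


last_kernel, all_kernels = jax.lax.scan(transform_kernel,
                                        init=some_kernel,
                                        xs=jnp.ones((10, d_x)))
```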

Here, LinearGaussianKernel is an Equinox module with three fields, two of which are functions. As expected, I get an error of the form `TypeError: Value <function LinearGaussianKernel.__init__.<locals>.<lambda> at 0x7b3530698280> with type <class 'function'> is not a valid JAX type.`

This is the same error I get when using jax.vmap instead of equinox.filter_vmap. What's the workaround for jax.lax.scan? Should I use equinox.partition and equinox.combine to write my own filter_scan, or is there something more canonical?

Thanks in advance.

mchagneux, Apr 19 '24 13:04