lax.scan for Equinox Modules?
Hi all,
I've created Equinox modules and I want to process them with `jax.lax.scan`, but I don't see any `equinox.filter_scan`. Is there a particular reason for that?
Here's some example code that doesn't work out of the box:
```python
import jax
import jax.numpy as jnp


def transform_kernel(carry, x):
    old_kernel = carry
    offset = x
    # Shift the bias of the linear mapping by this step's offset.
    new_mapping_bias = old_kernel.mapping._params['b'] + offset
    new_params = {'mapping': {'w': old_kernel.mapping._params['w'], 'b': new_mapping_bias},
                  'noise': {'loc': old_kernel.noise_dist.base_dist.loc,
                            'log_std': jnp.log(old_kernel.noise_dist.base_dist.scale)}}
    # Rebuild the kernel from the updated parameters.
    new_kernel = LinearGaussianKernel(new_params)
    return new_kernel, new_kernel


# `LinearGaussianKernel`, `some_kernel` and `d_x` are defined elsewhere.
last_kernel, all_kernels = jax.lax.scan(transform_kernel,
                                        init=some_kernel,
                                        xs=jnp.ones((10, d_x)))
```
Here, `LinearGaussianKernel` is an Equinox module with three fields, two of which are functions. As expected, I get an error of the form:
TypeError: Value <function LinearGaussianKernel.__init__.<locals>.<lambda> at 0x7b3530698280> with type <class 'function'> is not a valid JAX type.
This is the same error I get when using `jax.vmap` instead of `equinox.filter_vmap`. What's the workaround for `jax.lax.scan`? Should I be using `equinox.partition` and `equinox.combine` to create my own `filter_scan`, or is there something more canonical?
Thanks in advance.