
Replacement for `equinox.internal.if_mapped`?

Open michael-0brien opened this issue 1 year ago • 2 comments

Hi, I noticed that `equinox.internal.if_mapped` has been removed in a recent version of equinox rather than added to the main package. What is the reason for this, and what would be a reasonable replacement to replicate its behavior?

I have highly nested pytrees in my package, so `if_mapped` turns out to be very useful for saving memory when moving across vmap boundaries.

michael-0brien avatar Dec 19 '24 04:12 michael-0brien

So this hooked into some JAX internals that are getting removed in a future version of JAX. You could maybe replicate the same behavior by arranging for something updated and similar yourself, but I'm not sure exactly what.

The fact that this used such internals is why it was undocumented in equinox.internal instead of the main namespace, as I knew this might need to happen at some point!

FWIW I never really found a use-case for this in my own work as it meant that the output shape was hard to track.

Sorry that I can't give you better news!

patrick-kidger avatar Dec 19 '24 08:12 patrick-kidger

This is okay, makes sense! It is probably best practice to be explicit about the `out_axes` anyway. The best replacement for the behavior is probably to just create an `out_axes` pytree explicitly from a `filter_spec`, like the following:

out_axes = jax.tree.map(lambda x: 0 if x else None, filter_spec)

michael-0brien avatar Dec 19 '24 15:12 michael-0brien
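
For anyone landing here later, the `out_axes` construction from the last comment can be sketched end to end roughly as follows. This is a minimal illustration with plain `jax.vmap`, not the original `if_mapped` implementation. The function `f`, the dictionary keys, and the contents of `filter_spec` are hypothetical names made up for the example; the only assumption carried over from the thread is that `filter_spec` is a pytree of booleans marking which outputs are batched.

```python
import jax
import jax.numpy as jnp


def f(x, params):
    # Hypothetical function: "scaled" depends on the batched input x,
    # while "w" is passed through unchanged and so is never batched.
    return {"scaled": x * params["w"], "w": params["w"]}


# Hypothetical filter_spec: True marks outputs that vary across the batch.
filter_spec = {"scaled": True, "w": False}

# As in the comment above: batched leaves map over axis 0, the rest get None.
out_axes = jax.tree.map(lambda x: 0 if x else None, filter_spec)

xs = jnp.arange(3.0)
params = {"w": jnp.array(2.0)}

# params is not mapped (in_axes=None), so "w" can be returned with out_axes=None
# and is not broadcast along a new batch dimension.
out = jax.vmap(f, in_axes=(0, None), out_axes=out_axes)(xs, params)
```

With this spec, `out["w"]` keeps its unbatched shape instead of being copied once per batch element, which is where the memory saving comes from. The trade-off is that the spec has to be kept in sync with the function by hand: marking a genuinely batched output as `None` makes `jax.vmap` raise an error.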