
vmap can't be used with equinox module as op

Open dimitriye98 opened this issue 1 year ago • 2 comments

Using einx.vmap with an Equinox module as the op argument crashes if the module has any parameters, because JAX arrays are unhashable.

This is particularly problematic, since vmapping over a module is common in idiomatic Equinox code; the built-in Embedding module, for example, more or less requires it.
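A minimal sketch of why this crashes, assuming einx looks op up in a dict-based cache. The `Embedding` dataclass and the `try_cache_lookup` helper are hypothetical stand-ins, not Equinox or einx code. A frozen dataclass hashes its fields, so hashing an instance that holds a JAX array raises TypeError, while vmapping over calls to the same instance works fine.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Embedding:
    # Hypothetical stand-in for a module with one parameter (not eqx.nn.Embedding).
    weight: jax.Array

    def __call__(self, index):
        # Look up one row of the embedding table.
        return self.weight[index]


def try_cache_lookup(op, cache):
    # Hypothetical helper: using op as a dict key forces hash(op), which for a
    # frozen dataclass hashes its fields, including the unhashable JAX array.
    try:
        return cache.get(op)
    except TypeError as err:
        return f"TypeError: {err}"


emb = Embedding(jnp.arange(12.0).reshape(4, 3))

# vmapping over calls to the module works: one row per index.
rows = jax.vmap(lambda i: emb(i))(jnp.array([0, 2]))

# Hashing the module for a cache lookup does not.
result = try_cache_lookup(emb, {})
```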

dimitriye98 • Dec 22 '24 11:12

After some additional investigation, here is a potential path towards a solution: when tracing with the JAX backend, before looking up the cache with the input shapes and the function, check whether the function is a bound method. If it is, check whether the object it is bound to is a pytree. If so, map over the pytree to replace every array with its shape, and use the result together with the unbound method as the cache key instead of the bound method itself. I'll try to put together a PR in the coming days if I have time.
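The proposal above could be sketched roughly as follows, assuming the object a bound method is bound to is a JAX pytree. `pytree_signature` and `op_cache_key` are hypothetical helpers, not einx functions, and `Linear` is a toy pytree standing in for an Equinox module. The key combines the pytree's structure (the `treedef` returned by `jax.tree_util.tree_flatten`, which is hashable) with the shape and dtype of every array leaf.

```python
import inspect

import jax
import jax.numpy as jnp
import numpy as np


def _leaf_key(leaf):
    # Arrays are unhashable, so each one is replaced by its shape and dtype.
    if isinstance(leaf, (jax.Array, np.ndarray)):
        return ("array", tuple(leaf.shape), str(leaf.dtype))
    # Non-array leaves are assumed to be hashable already.
    return leaf


def pytree_signature(tree):
    # treedef encodes the structure (and static aux data); leaves become signatures.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    return treedef, tuple(_leaf_key(leaf) for leaf in leaves)


def op_cache_key(op):
    # Bound method: key on the unbound function plus the signature of its object.
    if inspect.ismethod(op):
        return (op.__func__, pytree_signature(op.__self__))
    # Anything else (e.g. a callable pytree) is keyed by its own signature.
    return pytree_signature(op)


@jax.tree_util.register_pytree_node_class
class Linear:
    # Toy module registered as a pytree; stands in for an Equinox module.
    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return x @ self.w

    def tree_flatten(self):
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


a = Linear(jnp.ones((3, 2)))
b = Linear(jnp.zeros((3, 2)))  # same shapes, different values
c = Linear(jnp.ones((4, 2)))   # different shapes
```

With this key, `a.__call__` and `b.__call__` map to the same cache entry even though their weights differ, which is the situation the reply below is concerned with.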

dimitriye98 • Dec 22 '24 12:12

Thanks for bringing this up! I don't think the proposed solution would work here, though. The op argument isn't passed through to the compiled function each time einx.vmap is called, since it is assumed to be static. Instead, a reference to the first op argument for a given signature is stored in the cache and used for all subsequent calls with the same signature. This isn't a problem if op is identical in all calls, but that wouldn't be the case here.

For example, if different instances of the same Equinox class are passed as op, the instance that einx.vmap is first called with (including the JAX arrays stored in it) will be used in all subsequent calls, even if a later call passes a different instance with different JAX arrays but the same shapes/signature.

I think the solution would be to make op a traced argument as well (or at least the arrays stored in it), so that it isn't captured in the compiled function and is instead treated like any other dynamic argument.

fferflo • Jan 13 '25 11:01