
`filter_shard` test crashes on JAX 0.4.29

Open patrick-kidger opened this issue 3 weeks ago • 4 comments

Looks like JAX used to do some "broadcasting" here, and no longer does. Bearing in mind that a PyTree may have arrays of multiple ranks, I'm not immediately sure what the appropriate fix is.

Tagging @homerjed, who originally contributed this; also @dlwh for visibility. What do you think?
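To make the rank mismatch concrete, here is a minimal, library-free sketch. It does not call JAX or Equinox: shapes stand in for arrays, and `fit_layout` is a hypothetical helper, not an existing API. It shows one assumed policy for adapting a single rank-2 device layout to leaves of different ranks. Whether that policy is the right fix for `filter_shard` is the open question.

```python
# Hypothetical illustration only: tuples of ints stand in for array shapes.
# A toy "PyTree" for one MLP layer: a rank-2 weight and a rank-1 bias.
params = {"weight": (8, 4), "bias": (8,)}

# A single rank-2 device layout, e.g. 2 devices along the first axis.
layout = (2, 1)


def fit_layout(layout, shape):
    """Adapt a device layout to one leaf's rank (an assumed policy, not JAX's).

    Missing trailing axes are padded with 1s; surplus trailing axes are
    dropped only if they are trivial (size 1). Otherwise raise ValueError.
    """
    rank = len(shape)
    if len(layout) > rank:
        if any(n != 1 for n in layout[rank:]):
            raise ValueError(f"cannot fit layout {layout} to rank {rank}")
        return tuple(layout[:rank])
    return tuple(layout) + (1,) * (rank - len(layout))


# One layout per leaf, each matching that leaf's rank.
per_leaf = {name: fit_layout(layout, shape) for name, shape in params.items()}
print(per_leaf)  # {'weight': (2, 1), 'bias': (2,)}
```

The rank-1 bias cannot take the rank-2 layout as-is, so the layout has to be adjusted per leaf, or the mismatch has to be treated as an error.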

patrick-kidger, Jun 12 '24 07:06