equinox
`filter_shard` test crashes on JAX 0.4.29
It looks like JAX used to do some implicit "broadcasting" here and no longer does. Since a single PyTree may contain arrays of different ranks, I'm not immediately sure what the appropriate fix is.
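To make the rank issue concrete, here is one possible direction, sketched under assumptions: build a separate sharding for each array leaf, with its `PartitionSpec` truncated or padded with `None` to match that leaf's rank. The helpers `rank_matched_sharding` and `shard_leaves` are hypothetical and not part of Equinox. The sketch places leaves with `jax.device_put` and does not reproduce what `filter_shard` does internally.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def rank_matched_sharding(mesh, axes, ndim):
    """Hypothetical helper: fit a tuple of mesh-axis names to an array of rank `ndim`.

    Extra entries are dropped; missing trailing dimensions are filled with None
    (i.e. not partitioned along any mesh axis).
    """
    entries = tuple(axes)[:ndim]
    entries = entries + (None,) * (ndim - len(entries))
    return NamedSharding(mesh, P(*entries))


def shard_leaves(tree, mesh, axes):
    """Hypothetical helper: place every array leaf with a rank-matched sharding.

    Non-array leaves (Python scalars, functions, ...) are returned unchanged.
    """
    def _place(leaf):
        if isinstance(leaf, jax.Array):
            return jax.device_put(leaf, rank_matched_sharding(mesh, axes, leaf.ndim))
        return leaf

    return jax.tree_util.tree_map(_place, tree)


# A 1-D mesh over all available devices (a single CPU device works too).
mesh = Mesh(np.array(jax.devices()), ("batch",))
n = 2 * jax.device_count()  # leading dimension divisible by the mesh size

tree = {
    "matrix": jnp.ones((n, 3)),  # rank 2 -> P("batch", None)
    "vector": jnp.arange(n),     # rank 1 -> P("batch")
    "scalar": jnp.array(1.0),    # rank 0 -> P(), i.e. replicated
    "step": 5,                   # not an array: left alone
}
sharded = shard_leaves(tree, mesh, ("batch", None))
```

Truncating the spec means lower-rank leaves, including scalars, are simply not partitioned along the dropped mesh axes. An alternative design would be to raise an error on a rank mismatch instead of adapting silently.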
Tagging @homerjed, who originally contributed this; also @dlwh for visibility. What do you think?