Patrick Kidger

267 comments by Patrick Kidger

Nicely found! So unfortunately, I don't think this is something that can be fixed in Equinox/Diffrax. The memory usage here is due to `jax.jit` (not `equinox.filter_jit` or `equinox.compile_utils.compile_cache`) producing a...

Yep -- the `filter_spec` argument to `tree_deserialise_leaves` may be a PyTree that is a prefix of the one being loaded, and so you can pass a custom `filter_spec` that does/doesn't...
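Here is a minimal plain-Python sketch of what "a prefix" means for a `filter_spec`. It is not Equinox's implementation. Nested dicts, lists and tuples stand in for PyTrees, and the helpers `broadcast_prefix`, `map_leaves`, `apply_spec`, `load_leaf` and `keep_leaf` are hypothetical. A spec leaf placed at some node is copied to every leaf beneath it, so one callable can cover a whole submodel.

```python
from typing import Any, Callable

# Stand-in PyTrees: dicts, lists and tuples are nodes; anything else is a leaf.
NODE_TYPES = (dict, list, tuple)


def map_leaves(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply `fn` to every leaf of `tree`, preserving its structure."""
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(fn, v) for v in tree)
    return fn(tree)


def broadcast_prefix(prefix: Any, tree: Any) -> Any:
    """Expand `prefix` to the full structure of `tree`.

    A leaf of `prefix` (e.g. a single callable) is copied to every leaf of the
    corresponding subtree of `tree`. Raises ValueError if `prefix` is not a
    prefix of `tree`.
    """
    if not isinstance(prefix, NODE_TYPES):
        return map_leaves(lambda _: prefix, tree)
    if type(prefix) is not type(tree):
        raise ValueError(f"node type mismatch: {type(prefix)} vs {type(tree)}")
    if isinstance(prefix, dict):
        if prefix.keys() != tree.keys():
            raise ValueError("dict keys differ")
        return {k: broadcast_prefix(prefix[k], tree[k]) for k in tree}
    if len(prefix) != len(tree):
        raise ValueError("sequence lengths differ")
    return type(tree)(broadcast_prefix(p, t) for p, t in zip(prefix, tree))


def apply_spec(spec: Any, like: Any) -> Any:
    """Call each leaf of the fully broadcast `spec` on the matching leaf of `like`."""
    if isinstance(like, dict):
        return {k: apply_spec(spec[k], like[k]) for k in like}
    if isinstance(like, (list, tuple)):
        return type(like)(apply_spec(s, x) for s, x in zip(spec, like))
    return spec(like)


def load_leaf(like_value: Any) -> Any:
    # Hypothetical: stands in for "read this leaf from the file".
    return "loaded"


def keep_leaf(like_value: Any) -> Any:
    # Hypothetical: skip reading and keep the value already in `like`.
    return like_value


like = {"linear": {"weight": 0.0, "bias": 0.0}, "name": ("mlp",)}
spec = broadcast_prefix({"linear": load_leaf, "name": keep_leaf}, like)
print(apply_spec(spec, like))
# {'linear': {'weight': 'loaded', 'bias': 'loaded'}, 'name': ('mlp',)}
```

The prefix `{"linear": load_leaf, "name": keep_leaf}` only names two subtrees, yet it ends up governing all three leaves.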

Yeah, that sounds reasonable. PyTorch and Equinox serialise in pretty different ways, though, so I'm not sure what a partial match would mean here. That one PyTree is a prefix of...
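This is a minimal sketch of the prefix relation, using the same kind of stand-in PyTrees: dicts, lists and tuples are nodes, and everything else is a leaf. The function `is_prefix` is hypothetical and is not part of JAX or Equinox. A tree counts as a prefix if it can be obtained from the full tree by collapsing some subtrees into single leaves.

```python
from typing import Any


def is_prefix(prefix: Any, tree: Any) -> bool:
    """Return True if `prefix` is a (stand-in) PyTree prefix of `tree`."""
    if not isinstance(prefix, (dict, list, tuple)):
        # A leaf in the prefix may stand in for any subtree, including a leaf.
        return True
    if type(prefix) is not type(tree):
        return False
    if isinstance(prefix, dict):
        return prefix.keys() == tree.keys() and all(
            is_prefix(prefix[k], tree[k]) for k in prefix
        )
    return len(prefix) == len(tree) and all(
        is_prefix(p, t) for p, t in zip(prefix, tree)
    )


full = {"a": (1, 2), "b": [3, {"c": 4}]}
print(is_prefix({"a": 0, "b": [0, 0]}, full))  # True
print(is_prefix({"a": 0}, full))               # False: key "b" is missing
```

The relation is not symmetric. The full tree is never a prefix of a tree that has collapsed one of its subtrees.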

This is surprisingly tricky. I've not managed to figure out a way to obtain the largest common prefix using just JAX's public API.
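For comparison, the largest common prefix is a short recursion when the trees are plain Python containers, because node types and children can be inspected directly. This is a hypothetical sketch, not a JAX-based solution. `largest_common_prefix` and the `"*"` placeholder are illustrative names. The difficulty described above concerns doing this through JAX's public PyTree API, and this sketch sidesteps that API entirely.

```python
from typing import Any

LEAF = "*"  # hypothetical placeholder marking "the tree is cut off here"


def largest_common_prefix(a: Any, b: Any) -> Any:
    """Largest tree that is a prefix of both `a` and `b` (stand-in PyTrees)."""
    same_dict = isinstance(a, dict) and isinstance(b, dict) and a.keys() == b.keys()
    same_seq = (
        isinstance(a, (list, tuple)) and type(a) is type(b) and len(a) == len(b)
    )
    if same_dict:
        return {k: largest_common_prefix(a[k], b[k]) for k in a}
    if same_seq:
        return type(a)(largest_common_prefix(x, y) for x, y in zip(a, b))
    # Leaves, or nodes whose structure differs: collapse to a single leaf.
    return LEAF


a = {"x": (1, 2), "y": [3, 4]}
b = {"x": (5, 6), "y": [7]}
print(largest_common_prefix(a, b))  # {'x': ('*', '*'), 'y': '*'}
```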

Yep, go for it. I think both partial deserialisation and user extensibility are worth handling. See also the discussion in https://github.com/google/jax/issues/11210 for context. I'd be happy to contemplate a larger...

I think having a class like this is probably overkill, and would introduce too much user complexity for a relatively niche feature. At least for the problem of user-extensibility, I'm...

Actually, I think at the moment if `like` is a prefix of the saved PyTree then the deserialisation will just go wrong. If the original PyTree has `n` leaves and...
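Here is a simplified model of how this can go wrong. It assumes, for illustration, that leaves are written to the file one after another with no record of the tree structure. The hypothetical `serialise`/`deserialise` helpers below store each leaf as an 8-byte float and do not use Equinox's actual format. If `like` collapses a subtree into one leaf, reading consumes the first leaves of the stream in order. Later values then land in the wrong place, and no error is raised.

```python
import io
import struct
from typing import Any, BinaryIO, List


def leaves(tree: Any) -> List[Any]:
    """Depth-first leaves of a stand-in PyTree (dict keys visited in sorted order)."""
    if isinstance(tree, dict):
        return [x for k in sorted(tree) for x in leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [x for v in tree for x in leaves(v)]
    return [tree]


def serialise(tree: Any, f: BinaryIO) -> None:
    # Hypothetical format: leaves written back to back as 8-byte floats,
    # with no record of the tree structure.
    for x in leaves(tree):
        f.write(struct.pack("<d", float(x)))


def deserialise(like: Any, f: BinaryIO) -> Any:
    # Walk `like` in the same order and read one value per leaf of `like`.
    if isinstance(like, dict):
        return {k: deserialise(like[k], f) for k in sorted(like)}
    if isinstance(like, (list, tuple)):
        return type(like)(deserialise(v, f) for v in like)
    return struct.unpack("<d", f.read(8))[0]


saved = {"a": [1.0, 2.0], "b": 3.0}  # three leaves
like = {"a": 0.0, "b": 0.0}          # a prefix of `saved`: "a" collapsed to one leaf

buf = io.BytesIO()
serialise(saved, buf)
buf.seek(0)
print(deserialise(like, buf))  # {'a': 1.0, 'b': 2.0}: "b" silently gets a's second leaf
```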

Right -- so this doesn't count as a PyTree prefix, and the fact that it works here is weird (and something we should think about handling better). Given a PyTree...

Something I spotted alongside this discussion: we're currently using `jax.tree_map` to do de/serialisation. It turns out that this doesn't actually guarantee the order in which...

Right -- so the answer is that whilst `tree_map` just so happens to be ordered, this is actually an implementation detail. That is, JAX reserves the right to change this behaviour about...
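The following plain-Python sketch illustrates the distinction. `reversed_call_order_map` is a hypothetical tree-map that returns exactly the right result but happens to call its function on the leaves in reverse order. A serialiser that writes to a file from inside the mapped function inherits that call order. Flattening into an explicit list first and looping over it pins the write order, provided the flattening order itself is fixed. None of these helpers are JAX APIs.

```python
from typing import Any, Callable, Iterator, List


def flatten(tree: Any) -> List[Any]:
    """Leaves in a fixed depth-first, left-to-right order (dict keys sorted)."""
    if isinstance(tree, dict):
        return [x for k in sorted(tree) for x in flatten(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [x for v in tree for x in flatten(v)]
    return [tree]


def unflatten(like: Any, values: Iterator[Any]) -> Any:
    """Rebuild the structure of `like`, taking leaves from `values` in flatten order."""
    if isinstance(like, dict):
        return {k: unflatten(like[k], values) for k in sorted(like)}
    if isinstance(like, (list, tuple)):
        return type(like)(unflatten(v, values) for v in like)
    return next(values)


def reversed_call_order_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Hypothetical tree-map: correct result, but calls `fn` on leaves in reverse."""
    leaves = flatten(tree)
    results: List[Any] = [None] * len(leaves)
    for i in reversed(range(len(leaves))):
        results[i] = fn(leaves[i])
    return unflatten(tree, iter(results))


def serialise_via_map(tree: Any, tree_map: Callable) -> List[Any]:
    """Writes happen in whatever order `tree_map` happens to call its function."""
    written: List[Any] = []

    def write(x: Any) -> Any:
        written.append(x)  # side effect: stands in for writing to a file
        return x

    tree_map(write, tree)
    return written


def serialise_via_flatten(tree: Any) -> List[Any]:
    """Writes happen in the order of an explicit loop over the flattened leaves."""
    written: List[Any] = []
    for x in flatten(tree):
        written.append(x)
    return written


tree = {"b": [3, 4], "a": 1}
print(serialise_via_flatten(tree))                       # [1, 3, 4]
print(serialise_via_map(tree, reversed_call_order_map))  # [4, 3, 1]
```

Both maps would pass a test that only checks the returned tree. Only the order of the side effects differs, and that is the part a serialiser depends on.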