Patrick Kidger

267 comments by Patrick Kidger

Ah, that's a nice idea for a tensor detail. Yes, that should be completely possible. Quick mock-up (untested):

```python
class _RequiresGradDetail(TensorDetail):
    def check(self, tensor: Tensor) -> bool:
        return tensor.requires_grad

    def...
```

Declaring the fields in advance is a syntax that we inherit from dataclasses. (Each `eqx.Module` is a dataclass.) Likewise, dataclasses don't allow adding additional fields at runtime. In principle Equinox...
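To make the dataclass point concrete, here is a minimal sketch using only the standard library. It assumes a plain frozen dataclass is a fair stand-in for an `eqx.Module`; the `Layer` class and its fields are hypothetical, made up for illustration.

```python
import dataclasses


# Hypothetical stand-in for an eqx.Module: fields are declared up front,
# and the frozen flag forbids attribute assignment after construction.
@dataclasses.dataclass(frozen=True)
class Layer:
    weight: float
    bias: float


layer = Layer(weight=2.0, bias=0.5)

# The declared fields are the only fields the instance has.
field_names = [f.name for f in dataclasses.fields(layer)]

# Trying to bolt on an undeclared attribute at runtime is rejected.
try:
    layer.extra = 1.0
    rejected = False
except dataclasses.FrozenInstanceError:
    rejected = True
```

Here `field_names` lists exactly the declared fields, and the assignment to `extra` fails rather than silently creating a new attribute.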

> jax arrays that are not "parameters"

JAX arrays can never be `static_field`s. A static field must be hashable, as it's used to form the cache key with `jax.jit`. (=you...
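A rough sketch of the hashability constraint, using plain `jax.jit` with `static_argnums` rather than Equinox. The helper `is_hashable` and the function `scale` are hypothetical names introduced for this example.

```python
import functools

import jax
import jax.numpy as jnp


def is_hashable(x) -> bool:
    """Return True if hash(x) succeeds (hypothetical helper)."""
    try:
        hash(x)
    except TypeError:
        return False
    return True


# JAX arrays are unhashable, so they cannot be part of a jit cache key.
array_ok = is_hashable(jnp.array([1.0, 2.0]))

# Ordinary Python values such as ints and tuples are hashable.
tuple_ok = is_hashable((1, 2))


# A static argument is hashed to form the cache key; here `n` is static.
@functools.partial(jax.jit, static_argnums=1)
def scale(x, n):
    return x * n


out = scale(jnp.array(2.0), 3)
```

Calling `scale` with a different value of `n` triggers a recompile, which is why anything marked static needs a well-defined hash.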

In JAX, a PyTree consists of two parts: the tree structure, and the leaves. For example, the PyTree `["hi", 2, (jnp.array(3.),)]` has structure `[*, *, (*,)]`, and leaves `"hi", 2,...
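The structure/leaves split can be inspected directly with `jax.tree_util`. A small sketch, using the same example PyTree as above:

```python
import jax
import jax.numpy as jnp

tree = ["hi", 2, (jnp.array(3.0),)]

# Split the PyTree into its flat list of leaves and its structure (treedef).
leaves, treedef = jax.tree_util.tree_flatten(tree)

# The structure alone is enough to rebuild the tree from a list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

The string, the integer and the array each come out as one leaf, while the list and tuple containers live in `treedef`.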

> If you want to have as above but with b static in `__call__` but not use the `static_field`.

What is it you're looking to accomplish here, precisely? "Static" basically just...

> I'm not sure if there is a more standard way of creating the filter spec than directly reaching for `object.__setattr__`.

Yup, use [`eqx.tree_at`](https://docs.kidger.site/equinox/api/manipulation/#equinox.tree_at) (which is also used in the...

Can you: (a) construct a `like` that is of the correct structure (by creating a dummy `Model(extend=False)`), and use that to deserialise; (b) construct a model that is of the...

The implementation is now available here: https://github.com/patrick-kidger/equinox/blob/main/examples/unet.py

Closing with #573, as I think this is now resolved.

Closing as I think this should be resolved now. We no longer use the old `host_callback` mechanism, and `equinox.internal.branched_error_if` should now work robustly on all backends.