Patrick Kidger
Ah, that's a nice idea for a tensor detail. Yes, that should be completely possible. Quick mock-up (untested):

```python
class _RequiresGradDetail(TensorDetail):
    def check(self, tensor: Tensor) -> bool:
        return tensor.requires_grad

    def...
```
Declaring the fields in advance is a syntax that we inherit from dataclasses. (Each `eqx.Module` is a dataclass.) Likewise, dataclasses don't allow adding additional fields at runtime. In principle Equinox...
> jax arrays that are not "parameters"

JAX arrays can never be `static_field`s. A static field must be hashable, as it's used to form the cache key with `jax.jit`. (=you...
In JAX, a PyTree consists of two parts: the tree structure, and the leaves. For example, the PyTree `["hi", 2, (jnp.array(3.),)]` has structure `[*, *, (*,)]`, and leaves `"hi", 2,...
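The same PyTree can be split into its structure and leaves with `jax.tree_util.tree_flatten`. This sketch also shows that two PyTrees with different leaves can share a structure, and that the structure can be recombined with fresh leaves.

```python
import jax
import jax.numpy as jnp

tree = ["hi", 2, (jnp.array(3.0),)]
leaves, treedef = jax.tree_util.tree_flatten(tree)

print(treedef)  # PyTreeDef([*, *, (*,)])
print(leaves[:2])  # ['hi', 2]  (the third leaf is the array 3.0)

# A different PyTree with the same structure but different leaves.
other = [0, 0, (0,)]
print(jax.tree_util.tree_structure(other) == treedef)  # True

# Recombine the structure with new leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, ["bye", 5, 7.0])
print(rebuilt)  # ['bye', 5, (7.0,)]
```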
> If you want to have as above but with b static in `__call__` but not use the `static_field`.

What is it you're looking to accomplish here, precisely? "Static" basically just...
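For reference, a sketch of what "static" means for a JIT-compiled function in plain JAX, using `jax.jit`'s `static_argnums`. The function `f` and the trace counter are illustrative only. A static argument's value is baked into the compiled program, so a new value triggers a fresh trace and compile. (`eqx.filter_jit` does the equivalent automatically for any argument that is not an array.)

```python
import jax
import jax.numpy as jnp

num_traces = 0


def f(x, b):
    global num_traces
    num_traces += 1  # Python side effect: runs only while JAX is tracing `f`
    return x * b


# Mark argument 1 (`b`) as static: its value becomes part of the cache key.
f_jit = jax.jit(f, static_argnums=1)

x = jnp.array([1.0, 2.0])
f_jit(x, 2)  # traces and compiles for b=2
f_jit(x, 2)  # same static value: reuses the cached compilation
f_jit(x, 3)  # new static value: traces and compiles again
print(num_traces)  # 2
```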
> I'm not sure if there is a more standard way of creating the filter spec that directly reaching for `object.__setattr__`.

Yup, use [`eqx.tree_at`](https://docs.kidger.site/equinox/api/manipulation/#equinox.tree_at) (which is also used in the...
Can you: (a) construct a `like` that is of the correct structure (by creating a dummy `Model(extend=False)`), and use that to deserialise; (b) construct a model that is of the...
The implementation is now available here: https://github.com/patrick-kidger/equinox/blob/main/examples/unet.py
Closing with #573, as I think this is now resolved.
Closing as I think this should be resolved now. We no longer use the old `host_callback` mechanism, and `equinox.internal.branched_error_if` should now work robustly on all backends.