Patrick Kidger
> The problem, as far as I can tell, is that for typecheckers `torch.Tensor` and something like `Float32[torch.Tensor, "4"]` are two different types, even though they're the same at runtime....
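For illustration, here's a minimal sketch of the mismatch (the exact diagnostic depends on your type checker):

```python
import torch
from jaxtyping import Float32

def f(x: Float32[torch.Tensor, "4"]) -> Float32[torch.Tensor, "4"]:
    return x

x: torch.Tensor = torch.zeros(4)
f(x)  # a static type checker may flag this: `torch.Tensor` != `Float32[torch.Tensor, "4"]`
assert isinstance(x, Float32[torch.Tensor, "4"])  # ...but at runtime the check passes
```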
Yup, `stop_gradient` is the way to do this :)
I think it's back :) Sometimes it will be down for an hour or so for server maintenance.
As the others have noted, this is just `jax.lax.scan` being weird, I'm afraid :)
It should work on methods the same way it works on regular functions, and `self` isn't privileged. Thus, any parts of the pytree structure of `self` that are arrays are...
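A minimal sketch of what that looks like in practice (a hypothetical module; shown here with `eqx.filter_jit`, but the same holds for `eqx.filter_grad` etc.):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Scale(eqx.Module):
    weight: jax.Array  # an array leaf of `self`'s pytree -- traced like any other argument

    @eqx.filter_jit  # decorating a method: `self` is just the first argument
    def __call__(self, x):
        return self.weight * x

model = Scale(weight=jnp.ones(3))
print(model(jnp.arange(3.0)))  # the array parts of `self` are traced; the rest is static
```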
Okay, so! I think we can change this. First of all, `norm` is assigned here: https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/nn/_weight_norm.py#L91, and I assume the fact that it dynamically creates new partials means that...
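For what it's worth, here's why freshly-created partials are suspicious (an illustrative sketch, not the actual WeightNorm code): `functools.partial` doesn't define `__eq__`, so two identical-looking partials compare unequal, and anything that compares static pytree metadata by equality will see a "new" structure each time.

```python
import functools
import jax.numpy as jnp

f1 = functools.partial(jnp.linalg.norm, axis=0)
f2 = functools.partial(jnp.linalg.norm, axis=0)
print(f1 == f2)  # False -- `functools.partial` has no `__eq__`, so this is identity comparison
# So if a fresh partial is stored as static metadata on every call, the pytree
# structure never compares equal, and jit re-traces each time.
```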
Can you try this with different versions of JAX? This seems like an underlying JAX issue rather than an Equinox one, and perhaps it only triggers in certain versions.
Some kind of cyclic garbage during our jit-dispatch seems like a plausible reason, I agree. It should probably be possible to identify this with enough `gc` magic.
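A sketch of the kind of `gc` magic I have in mind (`run_dispatch` is a hypothetical stand-in for whatever code path is under suspicion):

```python
import gc

gc.set_debug(gc.DEBUG_SAVEALL)  # keep everything the collector frees in `gc.garbage`
gc.collect()                    # flush pre-existing cycles so the baseline is clean

run_dispatch()                  # hypothetical: the jit-dispatch code under investigation

n = gc.collect()                # collect again; nonzero means cycles were created above
print(f"{n} objects were only reachable via cycles")
for obj in gc.garbage:
    print(type(obj))            # inspect which kinds of objects participate in the cycles
```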
None that I use myself. I'm sure something must exist, though! So if you find something, please report back :) There are various pytree pretty-printers, but it sounds like...
You want `jax.lax.stop_gradient`. This will block autodiff from operating through that variable. `eqx.field(static=True)` does something totally different -- it marks a dataclass field as not being part of the pytree...
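Roughly, the difference looks like this (a hypothetical module, just to contrast the two):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Model(eqx.Module):
    weight: jax.Array                   # trainable: gradients flow as normal
    frozen: jax.Array                   # still a pytree leaf; gradients blocked below
    name: str = eqx.field(static=True)  # not a leaf at all: hashable metadata, not an array

    def __call__(self, x):
        return self.weight * x + jax.lax.stop_gradient(self.frozen)

model = Model(weight=jnp.ones(3), frozen=jnp.ones(3), name="demo")
grads = eqx.filter_grad(lambda m, x: m(x).sum())(model, jnp.arange(3.0))
print(grads.frozen)  # all zeros: `stop_gradient` blocked autodiff through `frozen`
```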