Should BatchNorm axis_name be static?
When I try model = jax.device_put_replicated(nn, devices), JAX complains that the BatchNorm axis_name is not a JAX type.
Should we make axis_name static?
class BatchNorm(StatefulLayer):
    axis_name: Hashable | Sequence[Hashable] = field(static=True)
Does this only happen with BatchNorm? I would assume JAX wants every leaf of the pytree to be a JAX array, so my go-to would be filter/combine (as opposed to marking things static).
My model is a simple ConvNet, so it's batch norm only so far.
Yes, partition/combine would work, but that's a little clunky. In the current code many other properties are already marked static, e.g.
    channelwise_affine: bool = field(static=True)
    momentum: float = field(static=True)
    mode: Literal["ema", "batch"] = field(static=True)
so I guess it makes sense to treat axis_name the same way?
This seems reasonable to me, but it would be a slightly breaking change, in that it does mean that this field can no longer be replaced using eqx.tree_at.
(Not the end of the world, could still be handled by using tree_at as far as the BatchNorm object and then using dataclasses.replace.)
No strong feelings either way from me I think.
Ah, I didn't know about the implications of this change. I can definitely go with the partition/combine route.