equinox

BatchNorm axis_name should be static?

Open hyu2000 opened this issue 4 months ago • 4 comments

When I try:

```python
model = jax.device_put_replicated(nn, devices)
```

it complains that the BatchNorm axis_name is not a JAX type.

Should we make axis_name static?

```python
class BatchNorm(StatefulLayer):
    axis_name: Hashable | Sequence[Hashable] = field(static=True)
```

hyu2000, Aug 13 '25 19:08

Does this only happen with batch norm? I would assume JAX wants the leaves of the pytree to be JAX arrays, so my go-to would be to filter/combine (as opposed to marking things static).

lockwo, Aug 13 '25 19:08

My model is a simple ConvNet, so BatchNorm is the only affected layer so far.

Yes, partition/combine would work, but that's a little clunky. In the current code many other properties are already marked static, e.g.

```python
channelwise_affine: bool = field(static=True)
momentum: float = field(static=True)
mode: Literal["ema", "batch"] = field(static=True)
```

So I guess it makes sense to treat axis_name the same way?

hyu2000, Aug 13 '25 20:08

This seems reasonable to me, but it would be a slightly breaking change, in that it does mean that this field can no longer be replaced using eqx.tree_at.

(Not the end of the world: it could still be handled by using tree_at to reach the BatchNorm object and then calling dataclasses.replace on it.)

No strong feelings either way from me I think.

patrick-kidger, Aug 14 '25 07:08

Ah, I didn't know about that implication of the change. I can definitely go with the partition/combine route.

hyu2000, Aug 15 '25 14:08