funsor
Funsor function that can accept a variable number of Bound variables
The `mean`, `var`, and `standardize` functions in Section 3.1.6 (Normalization layers) accept multiple named axes (e.g., the two axes `batch` and `layer` in `BatchNorm`, and the single `layer` axis in `LayerNorm`). How can I define `Mean` and `Standardize` below so that they can accept a varying number of bound variables?
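```python
import funsor.ops as ops
from funsor.factory import Bound, Fresh, make_funsor
from funsor.terms import Funsor

# Average X over the single bound input `ax`.
@make_funsor
def Mean(
    X: Funsor,
    ax: Bound
) -> Fresh[lambda X: X]:
    return X.reduce(ops.add, ax) / ax.output.size

# Standardize X over `ax`; assumes a Variance funsor defined like Mean (not shown).
@make_funsor
def Standardize(
    X: Funsor,
    ax: Bound
) -> Fresh[lambda X: X]:
    return (X - Mean(X, ax)) / (Variance(X, ax) + ops.finfo(X.data).eps).sqrt()
```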
This is not (yet) possible with the current implementation of `make_funsor`, but we'll need something like this if we want to rewrite more of `funsor.terms` with `make_funsor`.
A minimal solution would be to define a `BoundSet` hint

```python
import typing

BoundSet = typing.FrozenSet[Bound]
```

and hard-code support for `BoundSet` inside `make_funsor`. Then we could write, e.g.:
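```python
from functools import reduce

# Requires make_funsor to support the proposed BoundSet hint.
@make_funsor
def Mean(
    X: Funsor,
    axes: BoundSet
) -> Fresh[lambda X: X]:
    # Sum over all bound axes, then divide by the product of their sizes.
    return X.reduce(ops.add, axes) / reduce(ops.mul, [ax.output.size for ax in axes])
```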
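What "hard-code support for `BoundSet`" would involve inside `make_funsor` isn't spelled out here, so the following is only a toy stdlib sketch of the detection step: recognizing a `typing.FrozenSet[Bound]` annotation with `typing.get_origin`/`typing.get_args` and flattening such an argument into individual bound names. `Bound` and `Var` are hypothetical stand-ins (not funsor's classes), and `classify_params` and `bound_names` are invented helpers, not funsor API.

```python
import inspect
import typing

class Bound:
    """Hypothetical stand-in for funsor's Bound marker type."""

# The proposed BoundSet hint.
BoundSet = typing.FrozenSet[Bound]

class Var(typing.NamedTuple):
    """Hypothetical stand-in for a funsor Variable: just a name and a size."""
    name: str
    size: int

def classify_params(fn):
    """Map each parameter of fn to 'bound', 'bound_set', or 'other'."""
    kinds = {}
    for name, param in inspect.signature(fn).parameters.items():
        hint = param.annotation
        if hint is Bound:
            kinds[name] = "bound"
        elif typing.get_origin(hint) is frozenset and typing.get_args(hint) == (Bound,):
            kinds[name] = "bound_set"
        else:
            kinds[name] = "other"
    return kinds

def bound_names(fn, *args):
    """Flatten all Bound and BoundSet arguments of fn into one frozenset of names."""
    names = set()
    for kind, arg in zip(classify_params(fn).values(), args):
        if kind == "bound":
            names.add(arg.name)
        elif kind == "bound_set":
            names.update(v.name for v in arg)
    return frozenset(names)

def mean_one(X: object, ax: Bound): ...
def mean_many(X: object, axes: BoundSet): ...

print(classify_params(mean_many))  # {'X': 'other', 'axes': 'bound_set'}
print(sorted(bound_names(mean_many, None, frozenset({Var("batch", 2), Var("layer", 3)}))))
# ['batch', 'layer']
```

Using a single frozenset-valued parameter rather than `*args` keeps the reduced axes unordered and hashable, which matches how the `Mean` example above passes `axes` directly to `X.reduce`.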