
Funsor function that can accept a variable number of Bound variables

Open · ordabayevy opened this issue 3 years ago · 1 comment

The mean, var, and standardize functions in 3.1.6 Normalization layers accept multiple named axes (e.g., two axes, batch and layer, in BatchNorm and one axis, layer, in LayerNorm). How can I define Mean and Standardize below so that they can accept a varying number of bound variables?

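```python
from funsor import ops
from funsor.factory import Bound, Fresh, make_funsor
from funsor.terms import Funsor

@make_funsor
def Mean(
    X: Funsor,
    ax: Bound
) -> Fresh[lambda X: X]:
    # Sum over the single bound axis, then divide by its size.
    return X.reduce(ops.add, ax) / ax.output.size

@make_funsor
def Standardize(
    X: Funsor,
    ax: Bound
) -> Fresh[lambda X: X]:
    # Variance is defined elsewhere (not shown), analogously to Mean.
    return (X - Mean(X, ax)) / (Variance(X, ax) + ops.finfo(X.data).eps).sqrt()
```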

ordabayevy, Mar 07 '21 21:03

This is not (yet) possible with the current implementation of make_funsor, but we'll need something like this if we want to rewrite more of funsor.terms with make_funsor.

A minimal solution would be to define a BoundSet type hint:

```python
BoundSet = typing.FrozenSet[Bound]
```

and hard-code support for BoundSet inside make_funsor. Then we could write, e.g.:

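```python
from functools import reduce
from funsor import ops
from funsor.factory import Fresh, make_funsor
from funsor.terms import Funsor

@make_funsor
def Mean(
    X: Funsor,
    axes: BoundSet  # proposed hint; not yet supported by make_funsor
) -> Fresh[lambda X: X]:
    # Sum over every bound axis, then divide by the product of their sizes.
    return X.reduce(ops.add, axes) / reduce(ops.mul, [ax.output.size for ax in axes])
```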
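Hard-coding support for BoundSet mostly amounts to recognizing the hint when make_funsor inspects a function's signature. The following is a sketch of that detection step only, using the standard `typing` and `inspect` modules. It is not funsor's implementation: `Bound` here is a placeholder class rather than `funsor.factory.Bound`, and `bound_params`, `mean_one`, and `mean_many` are hypothetical names.

```python
import inspect
import typing


class Bound:
    """Hypothetical placeholder for funsor's Bound hint, used here only as a marker."""


# The proposed hint: a frozenset whose elements are bound variables.
BoundSet = typing.FrozenSet[Bound]


def bound_params(fn):
    """Split fn's parameters into those hinted Bound and those hinted BoundSet."""
    single, multi = [], []
    for name, param in inspect.signature(fn).parameters.items():
        hint = param.annotation
        if hint is Bound:
            single.append(name)
        # FrozenSet[Bound] (and frozenset[Bound] on Python 3.9+) has origin frozenset, args (Bound,).
        elif typing.get_origin(hint) is frozenset and typing.get_args(hint) == (Bound,):
            multi.append(name)
    return single, multi


def mean_one(X, ax: Bound):
    ...


def mean_many(X, axes: BoundSet):
    ...
```

A real make_funsor would then have to treat each element of a BoundSet argument as a bound variable when computing the term's inputs; this sketch does not attempt that part.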

eb8680, Mar 07 '21 23:03