
BindReturn type hint for make_funsor

Open · ordabayevy opened this pull request · 5 comments

Addresses #481.

The BindReturn type hint marks a variable that is both bound inside another argument and returned fresh in the output. For example:

@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: BindReturn[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})

x = random_tensor(OrderedDict(a=Bint[5]))
with reflect:
    y = Unroll(x, "a", 2, "kernel")
assert y.fresh == frozenset({"a", "kernel"})
assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound)
check_funsor(y, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real)
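For intuition, the body of Unroll just re-indexes x at a shifted position. A rough plain-funsor equivalent of the example above (written directly with funsor.Variable, without make_funsor) would be:

from funsor import Bint, Variable

# Re-index x at a + kernel: the new "a" ranges over Bint[5 - 2 + 1] and
# "kernel" over Bint[2], matching the domains checked above.
a = Variable("a", Bint[4])
kernel = Variable("kernel", Bint[2])
y_direct = x(a=a + kernel)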

Or, as a second example:

@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: BindReturn[lambda ax: ax],
) -> Fresh[lambda x: x]:
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()

x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4]))
with reflect:
    y = Softmax(x, "a")
assert y.fresh == frozenset({"a"})
assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound)
check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real)
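Outside of reflect, the default rule is the usual log-sum-exp normalization over the named input, which one could also write directly:

# Direct equivalent of Softmax's body, normalizing x over input "a":
y_direct = (x - x.reduce(ops.logaddexp, "a")).exp()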

ordabayevy avatar Apr 04 '21 02:04 ordabayevy

Thanks for adding this! WDYT about making BindReturn the default behavior of Fresh? That would be consistent with the default behavior of existing terms like Cat, Stack, and Independent, and it would avoid growing the number of special make_funsor type annotations. The conservative alternative is having make_funsor raise an error when a Fresh variable appears in the inputs of another argument.
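For concreteness, the conservative check might look roughly like the sketch below; _check_fresh_not_bound is a hypothetical helper, and the real make_funsor stores its parsed annotations differently, so this only illustrates the intended error:

def _check_fresh_not_bound(fresh_names, has_deps):
    # fresh_names: set of argument names annotated with Fresh,
    #     e.g. {"ax", "kernel"}
    # has_deps: dict mapping each Has-annotated argument name to its
    #     declared variable names, e.g. {"x": {"ax"}}
    for arg, deps in has_deps.items():
        collisions = fresh_names & deps
        if collisions:
            raise ValueError(
                f"Fresh variable(s) {sorted(collisions)} appear in the "
                f"inputs of argument '{arg}'"
            )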

eb8680 avatar Apr 05 '21 04:04 eb8680

WDYT about making BindReturn the default behavior of Fresh?

Do you mean that, in the example below, the Fresh type hint would be smart enough to make ax both bound and fresh, and to make kernel only fresh?

@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})

ordabayevy avatar Apr 05 '21 04:04 ordabayevy

Do you mean that, in the example below, the Fresh type hint would be smart enough to make ax both bound and fresh, and to make kernel only fresh?

Yes, exactly.

On a related note, we should also start using funsor.domains.Dependent rather than Fresh for annotating return types of make_funsor, but that's for another PR.

eb8680 avatar Apr 05 '21 04:04 eb8680

Yes, exactly.

I like the idea. I'll make the changes then.

ordabayevy avatar Apr 05 '21 04:04 ordabayevy

Sorry for taking so long to review this (especially since I suggested you try it in the first place). I am still not sure how to go about fixing alpha-conversion as a whole in a way that remains compatible with cons-hashing, so I had put off thinking about details.

I think that for the behavior implemented in this PR to be safe and correct by construction, we would need to eagerly alpha-mangle the arguments of a make_funsor term with a bound Fresh variable before evaluating that term with a rewrite rule, so as to guarantee that there are no collisions between the fresh variable and the implicitly bound variable inside the rule body. In fact, we could always do this for any Funsor; the fact that we only alpha-convert in reflect is an (important) optimization.

We could write a simple decorator for rewrite rules to perform this extra step:

import functools

def bind_args(term):

    def binding_wrapper(rule):
        def wrapped_rule(*args):
            # Re-reflect the term so its bound variables are alpha-mangled
            # before the rule body sees the arguments.
            mangled_args = reflect.interpret(term, *args)._ast_values
            return rule(*mangled_args)
        return functools.wraps(rule)(wrapped_rule)

    return binding_wrapper

To illustrate the use of bind_args with your Softmax example, we could separate the term definition from the default rule and apply bind_args to the rule manually:

@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax: ax],
) -> Fresh[lambda x: x]:
    return None  # no default rule: the term stays unevaluated unless a registered rule matches

@eager.register(Softmax, Tensor, Variable)
@bind_args(Softmax)
def _eager_softmax(x, ax):
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()

Of course, this is less ergonomic than the original syntax, so I could imagine folding bind_args into make_funsor or even into interpretations. However, that would come at a considerable computational cost with the current implementation of alpha-conversion, and it is not necessary for this particular example.
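For example, a folded-in variant might build the term with an empty default body and register the annotated function itself, wrapped in bind_args, as the eager rule. This is only a rough sketch: make_funsor_bound is hypothetical, it assumes functools.wraps carries fn's annotations over to the placeholder, and a real implementation would derive the registration pattern from the argument annotations rather than using Funsor for every slot:

import functools
import inspect

def make_funsor_bound(fn):
    # Build the term with an empty default body (as in Softmax above);
    # functools.wraps copies fn's annotations so make_funsor can read them.
    term = make_funsor(functools.wraps(fn)(lambda *args: None))
    # Register fn itself as the default eager rule, alpha-mangling its
    # arguments first via bind_args.
    nargs = len(inspect.signature(fn).parameters)
    eager.register(term, *([Funsor] * nargs))(bind_args(term)(fn))
    return term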

eb8680 avatar May 12 '21 23:05 eb8680