                        BindReturn type hint for make_funsor
Addresses #481.
The BindReturn type hint marks a variable that is both bound in an input and returned fresh in the output. For example:
from collections import OrderedDict

import funsor.ops as ops
from funsor.domains import Bint, Real
from funsor.factory import BindReturn, Fresh, Has, Value, make_funsor
from funsor.interpretations import reflect
from funsor.testing import check_funsor, random_tensor

# Note: BindReturn is introduced by this PR; it is assumed to live in
# funsor.factory alongside Fresh, Has, and Value. These imports cover
# both examples below.

@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: BindReturn[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})
x = random_tensor(OrderedDict(a=Bint[5]))
with reflect:
    y = Unroll(x, "a", 2, "kernel")
assert y.fresh == frozenset({"a", "kernel"})
assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound)  # mangled names look like "a__BOUND_<n>"
check_funsor(y, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real)
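Outside of reflect, the same call is evaluated eagerly: the body substitutes ax + kernel into x, consuming the original size-5 input and producing a window-position input "a" and an offset input "kernel". A minimal sanity check (a sketch, assuming the eager interpretation is active by default):
y_eager = Unroll(x, "a", 2, "kernel")  # no reflect: the body runs immediately
assert dict(y_eager.inputs) == {"a": Bint[4], "kernel": Bint[2]}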
Another example:
@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: BindReturn[lambda ax: ax],
) -> Fresh[lambda x: x]:
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()
x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4]))
with reflect:
    y = Softmax(x, "a")
assert y.fresh == frozenset({"a"})
assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound)  # mangled names look like "a__BOUND_<n>"
check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real)
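Eager evaluation gives a quick sanity check that the definition really computes a softmax: the result should sum to one along the reduced axis. A sketch, assuming the torch backend (random_tensor returns a backend tensor wrapped as a funsor Tensor):
import torch

z = Softmax(x, "a")             # no reflect: evaluated eagerly
total = z.reduce(ops.add, "a")  # sum out the softmax axis
assert torch.allclose(total.data, torch.ones(4))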
Thanks for adding this! WDYT about making BindReturn the default behavior of Fresh? That would be consistent with the default behavior of existing terms like Cat, Stack, and Independent, and would avoid growing the number of special make_funsor type annotations. The conservative alternative is to have make_funsor raise an error when a Fresh variable appears in the inputs of another argument.
> WDYT about making BindReturn the default behavior of Fresh?
Do you mean that, in the example below, the Fresh type hint would be smart enough to make ax both bound and fresh, and kernel only fresh?
@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})
> Do you mean that, in the example below, the Fresh type hint would be smart enough to make ax both bound and fresh, and kernel only fresh?
Yes, exactly.
On a related note, we should also start using funsor.domains.Dependent rather than Fresh for annotating return types of make_funsor, but that's for another PR.
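For context, Dependent takes the argument domains rather than the argument funsors, so the return annotation might look roughly like the following. This is a hypothetical sketch only; make_funsor does not support Dependent yet:
from funsor.domains import Dependent

@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax: ax],
) -> Dependent[lambda x: x]:  # x here is a domain, e.g. Real
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()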
> Yes, exactly.
I like the idea. I'll make the changes then.
Sorry for taking so long to review this (especially since I suggested you try it in the first place). I am still not sure how to go about fixing alpha-conversion as a whole in a way that remains compatible with cons-hashing, so I had put off thinking about details.
I think for the behavior implemented in this PR to be safe and correct by construction in general, we would need to eagerly alpha-mangle the arguments to a make_funsor term with a bound Fresh variable before evaluating that term with a rewrite rule; this would guarantee that there are no collisions between the fresh variable and the implicitly bound variable inside the rule body. In fact, we could always do this for any Funsor; the fact that we only alpha-convert in reflect is an (important) optimization.
We could write a simple decorator for rewrite rules to perform this extra step:
import functools

def bind_args(term):
    def binding_wrapper(rule):
        def wrapped_rule(*args):
            # Reflect the term to trigger alpha-mangling of its bound
            # variables, then recover the mangled arguments from the
            # resulting lazy term.
            mangled_args = reflect.interpret(term, *args)._ast_values
            return rule(*mangled_args)
        return functools.wraps(rule)(wrapped_rule)
    return binding_wrapper
To illustrate the use of bind_args with your Softmax example, we could separate the term definition from the default rewrite rule and manually apply bind_args to the rule:
from funsor.interpretations import eager
from funsor.tensor import Tensor
from funsor.terms import Variable

@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax: ax],
) -> Fresh[lambda x: x]:
    return None  # no default rule: returning None keeps the term lazy
@eager.register(Softmax, Tensor, Variable)
@bind_args(Softmax)
def _eager_softmax(x, ax):
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()
Of course, this is less ergonomic than the original syntax, so I could imagine folding bind_args into make_funsor or even into interpretations, although this would come at a considerable computational cost with the current implementation of alpha-conversion and is not necessary for this particular example.