`BindReturn` type hint for `make_funsor`

Addresses #481.

The `BindReturn` type hint is used for binding and returning a variable. For example:
```python
@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: BindReturn[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})


x = random_tensor(OrderedDict(a=Bint[5]))
with reflect:
    y = Unroll(x, "a", 2, "kernel")
assert y.fresh == frozenset({"a", "kernel"})
assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound)
check_funsor(y, {"a": Bint[5 - 2 + 1], "kernel": Bint[2]}, Real)
```
or
```python
@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: BindReturn[lambda ax: ax],
) -> Fresh[lambda x: x]:
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()


x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4]))
with reflect:
    y = Softmax(x, "a")
assert y.fresh == frozenset({"a"})
assert all(bound in y.x.inputs and bound[1:8] == "__BOUND" for bound in y.bound)
check_funsor(y, {"a": Bint[3], "b": Bint[4]}, Real)
```
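For intuition, the reduction rule in `Softmax` is just the numerically stable softmax: subtract the log-sum-exp along the reduced axis, then exponentiate. A plain-Python sketch of the same computation (not funsor code; the `softmax` helper below is illustrative only):

```python
import math


def softmax(xs):
    # Mirror the funsor rule: y = x - logsumexp(x); return exp(y).
    # Subtracting the max first keeps math.exp from overflowing.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [math.exp(x - lse) for x in xs]


probs = softmax([1.0, 2.0, 3.0])
print(sum(probs))  # ≈ 1.0, up to floating-point error
```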
Thanks for adding this! WDYT about making `BindReturn` the default behavior of `Fresh`? That would be consistent with the default behavior of existing terms like `Cat`, `Stack`, and `Independent`, and would avoid growing the number of special `make_funsor` type annotations. The conservative alternative is having `make_funsor` raise an error when a `Fresh` variable appears in the inputs of another argument.
> WDYT about making `BindReturn` the default behavior of `Fresh`?
Do you mean that in the example below the `Fresh` type hint would be smart enough to make `ax` both bound and fresh, and make `kernel` only fresh?
```python
@make_funsor
def Unroll(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax, k: Bint[ax.size - k + 1]],
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
) -> Fresh[lambda x: x]:
    return x(**{ax.name: ax + kernel})
```
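(As an aside, the computation `Unroll` is meant to express is essentially a sliding window along `ax`: the output has a shortened `ax` of size `ax.size - k + 1` plus a fresh `kernel` input of size `k` that indexes within each window. A toy plain-Python sketch, not funsor code; the `unroll` helper is illustrative only:)

```python
def unroll(x, k):
    """Toy analogue of Unroll: view a length-n sequence as
    (n - k + 1) overlapping windows of length k, so that
    y[a][kernel] == x[a + kernel]."""
    n = len(x)
    return [[x[a + kernel] for kernel in range(k)] for a in range(n - k + 1)]


print(unroll([10, 11, 12, 13, 14], 2))
# → [[10, 11], [11, 12], [12, 13], [13, 14]]
```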
> Do you mean that in the example below the `Fresh` type hint would be smart to make `ax` both bound and fresh and make `kernel` only fresh?
Yes, exactly.
On a related note, we should also start using `funsor.domains.Dependent` rather than `Fresh` for annotating return types of `make_funsor`, but that's for another PR.
> Yes, exactly.
I like the idea. I'll make the changes then.
Sorry for taking so long to review this (especially since I suggested you try it in the first place). I am still not sure how to go about fixing alpha-conversion as a whole in a way that remains compatible with cons-hashing, so I had put off thinking about details.
I think for the behavior implemented in this PR to be safe and correct by construction in general, we would need to eagerly alpha-mangle the arguments to a `make_funsor` term with a bound `Fresh` variable before evaluating that term with a rewrite rule, to guarantee that there are no collisions between the fresh variable and the implicitly bound variable inside the rule body. In fact, we could always do this for any Funsor; the fact that we only alpha-convert in `reflect` is an (important) optimization.
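(To illustrate the collision being guarded against, here is a toy capture-avoiding substitution on tiny lambda terms — not funsor code, and the names are illustrative. Without renaming the binder first, substituting a term containing a free `y` under a binder named `y` would capture it:)

```python
import itertools

_counter = itertools.count()


def mangle(name):
    # Generate a globally fresh name, loosely analogous to the
    # "__BOUND"-prefixed names funsor's alpha-mangling produces.
    return f"__BOUND_{next(_counter)}_{name}"


def substitute(term, name, value):
    """Capture-avoiding substitution: str = variable,
    ("lam", v, body) = abstraction binding v."""
    if isinstance(term, str):
        return value if term == name else term
    tag, v, body = term
    if v == name:  # name is shadowed inside this binder; nothing to do
        return term
    fresh = mangle(v)  # rename the binder before descending, so free
    body = substitute(body, v, fresh)  # variables in `value` can't be captured
    return (tag, fresh, substitute(body, name, value))


# Substituting x := y under a binder named y: the binder gets mangled
# and the inserted y correctly stays free.
print(substitute(("lam", "y", "x"), "x", "y"))
```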
We could write a simple decorator for rewrite rules to perform this extra step:
```python
import functools


def bind_args(term):
    def binding_wrapper(rule):
        def wrapped_rule(*args):
            # Reconstruct the term under reflect so its bound variables are
            # alpha-mangled, then apply the rule to the mangled arguments.
            mangled_args = reflect.interpret(term, *args)._ast_values
            return rule(*mangled_args)

        return functools.wraps(rule)(wrapped_rule)

    return binding_wrapper
```
To illustrate the use of `bind_args`, in your `Softmax` example we could separate the term definition from the default reduction rule and manually apply `bind_args` to the rule:
```python
@make_funsor
def Softmax(
    x: Has[{"ax"}],  # noqa: F821
    ax: Fresh[lambda ax: ax],
) -> Fresh[lambda x: x]:
    return None


@eager.register(Softmax, Tensor, Variable)
@bind_args(Softmax)
def _eager_softmax(x, ax):
    y = x - x.reduce(ops.logaddexp, ax)
    return y.exp()
```
Of course, this is less ergonomic than the original syntax, so I could imagine folding `bind_args` into `make_funsor` or even into interpretations, although that would come at a considerable computational cost with the current implementation of alpha-conversion, and it is not necessary for this particular example.