funsor
What is the type hint for variables bound inside a function (e.g., `ax` in `Softmax`)?
What should the type hint be for `ax` in this `Softmax` function?

- `Funsor` works, but then `ax` cannot be a `str`.
- `Value[str]` can only be a `str`, not a `Variable`.
- `Bound` works both as a `str` and as a `Variable`, but `ax` is not actually bound.
```python
@make_funsor
def Softmax(
    x: Funsor,
    ax: ?  # Funsor or Bound or Value[str]
) -> Fresh[lambda x: x]:
    return x.exp() / x.exp().reduce(ops.add, ax)
```
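To make the semantics of the question's body concrete, here is a plain-Python sketch of what `x.exp() / x.exp().reduce(ops.add, ax)` computes when `x` has two named inputs `i` and `j` and `ax` is `j`. It does not use funsor: the dict-as-table representation and the `softmax_over_j` name are illustrative assumptions. The sum eliminates `j` only inside the normalizer, while the quotient is indexed by `j` again. That is the "bind and return a dimension with the same name" situation the answer refers to.

```python
import math
from typing import Dict, Tuple

# Hypothetical stand-in for a funsor with two named inputs "i" and "j":
# a dict mapping (i, j) index pairs to real values.
Table = Dict[Tuple[int, int], float]


def softmax_over_j(x: Table) -> Table:
    """Mirror x.exp() / x.exp().reduce(ops.add, "j") on a plain dict."""
    exp_x = {k: math.exp(v) for k, v in x.items()}
    # Reduce over j: one normalizer per value of i (j is consumed here).
    totals: Dict[int, float] = {}
    for (i, _j), v in exp_x.items():
        totals[i] = totals.get(i, 0.0) + v
    # Divide each entry by its row's normalizer; the result is indexed by j again.
    return {(i, j): v / totals[i] for (i, j), v in exp_x.items()}


x = {(0, 0): 1.0, (0, 1): 2.0, (1, 0): 0.5, (1, 1): 0.5}
y = softmax_over_j(x)
print(sorted(y) == sorted(x))            # True: same index set as the input
print(round(y[(1, 0)], 6))               # 0.5
print(round(y[(0, 0)] + y[(0, 1)], 6))   # 1.0
```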
You'll need `ax: Bound` and `ax2: Fresh[lambda ax: ax]`. As we discussed yesterday, `@make_funsor` currently does not support binding and returning a dimension with the same name, so for now you'll need to rename the dimension. I believe @eb8680 is planning to address that as part of the alpha-renaming work.
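```python
@make_funsor
def Softmax(
    x: Funsor,
    ax: Bound,                  # bound: this dimension does not appear in the result
    ax2: Fresh[lambda ax: ax],  # new output dimension with the same domain as ax
) -> Fresh[lambda x: x]:
    x = x(**{ax: ax2})  # rename to work around alpha renaming limitations
    # Log-normalize along ax2 (subtract the log-sum-exp over ax2), then exponentiate.
    y = x - x.reduce(ops.logaddexp, ax2)
    return y.exp()
```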
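The answer's body computes `exp(x - logsumexp(x))` along `ax2` instead of the question's `exp(x) / sum(exp(x))`. The two are mathematically equal. One likely practical benefit of the log-space form is numerical stability. The sketch below illustrates this in plain Python on a single list of floats standing in for one axis. It is not funsor's API: the function names are hypothetical, and `logsumexp` uses the standard max-shift implementation.

```python
import math
from typing import List


def softmax_naive(xs: List[float]) -> List[float]:
    # Direct translation of exp(x) / sum(exp(x)); math.exp raises
    # OverflowError once an entry exceeds roughly 709.
    exps = [math.exp(v) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


def logsumexp(xs: List[float]) -> float:
    # Shift by the max so every exponent is <= 0 before exponentiating.
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def softmax_stable(xs: List[float]) -> List[float]:
    # exp(x - logsumexp(x)): the same quantity, computed in log space,
    # analogous to (x - x.reduce(ops.logaddexp, ax2)).exp().
    lse = logsumexp(xs)
    return [math.exp(v - lse) for v in xs]


small = [1.0, 2.0, 3.0]
print(all(abs(a - b) < 1e-12 for a, b in zip(softmax_naive(small), softmax_stable(small))))  # True

large = [1000.0, 1000.0]
print([round(p, 12) for p in softmax_stable(large)])  # [0.5, 0.5]
try:
    softmax_naive(large)
except OverflowError:
    print("naive version overflowed")
```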