funsor
Support size variables in funsor.domains (dependent types)
Currently `@funsor.torch.function` must wrap each sized matmul individually, e.g.
```python
from funsor.domains import reals
from funsor.torch import function

@function(reals(2, 3), reals(3, 4), reals(2, 4))
def matmul_2_3_4(x, y):
    return x.matmul(y)
```
Could we generalize this to einsum-style syntax with strings as size variables?
```python
@function(reals("a", "b"), reals("b", "c"), reals("a", "c"))
def matmul(x, y):
    return x.matmul(y)
```
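The string-variable syntax amounts to unifying shape patterns against the concrete input shapes, then substituting the bindings into the output pattern. A minimal, framework-free sketch, assuming domains are reduced to plain shape tuples; `bind_sizes` and `substitute` are hypothetical helpers, not part of funsor:

```python
from typing import Dict, Sequence, Tuple, Union

Dim = Union[int, str]  # a concrete size or a named size variable


def bind_sizes(patterns: Sequence[Tuple[Dim, ...]],
               shapes: Sequence[Tuple[int, ...]]) -> Dict[str, int]:
    """Bind string size variables to concrete sizes, checking consistency."""
    env: Dict[str, int] = {}
    for pattern, shape in zip(patterns, shapes):
        if len(pattern) != len(shape):
            raise TypeError(f"expected rank {len(pattern)}, got shape {shape}")
        for dim, size in zip(pattern, shape):
            if isinstance(dim, str):
                # The first occurrence binds the variable; later ones must agree.
                if env.setdefault(dim, size) != size:
                    raise TypeError(f"size variable {dim!r} bound to both {env[dim]} and {size}")
            elif dim != size:
                raise TypeError(f"expected size {dim}, got {size}")
    return env


def substitute(pattern: Tuple[Dim, ...], env: Dict[str, int]) -> Tuple[int, ...]:
    """Replace size variables in an output pattern with their bound sizes."""
    return tuple(env[d] if isinstance(d, str) else d for d in pattern)


# matmul signature: ("a", "b"), ("b", "c") -> ("a", "c")
env = bind_sizes([("a", "b"), ("b", "c")], [(2, 3), (3, 4)])
print(substitute(("a", "c"), env))  # (2, 4)
```

A mismatch such as shapes `(2, 3)` and `(5, 4)` raises `TypeError`, because `"b"` would be bound to both 3 and 5.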
@fritzo and I came up with what seems like a nice design in #442. Inspired by the notation in that PR, we might write the `matmul` example above as
```python
@function
def matmul(
    x: Array,
    y: Array
) -> Dependent[lambda x, y: Array[x.dtype, (x.shape[0], y.shape[1])]]:
    return x.matmul(y)
```
where `Dependent`, like `Fresh` in #442, takes a `lambda` as an argument; that `lambda` receives the domains of the arguments `x` and `y` at the time `matmul` is called and returns the result domain.
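A minimal sketch of how such an annotation could be evaluated, assuming a hypothetical `ArrayDomain` named tuple as a stand-in for funsor domains and a hypothetical `output_domain` helper that reads the return annotation with `inspect.signature`; this is not how funsor's `function` is actually implemented:

```python
import inspect
from typing import NamedTuple, Tuple


class ArrayDomain(NamedTuple):
    """Hypothetical stand-in for a funsor domain: a dtype plus a shape."""
    dtype: str
    shape: Tuple[int, ...]


class Dependent:
    """Hypothetical annotation wrapping a function from input domains to an output domain."""

    def __init__(self, fn):
        self.fn = fn

    def __class_getitem__(cls, fn):
        # Dependent[lambda ...] builds an instance holding the lambda.
        return cls(fn)


def output_domain(fn, *input_domains):
    """Evaluate fn's Dependent return annotation on the domains of its arguments."""
    ret = inspect.signature(fn).return_annotation
    if isinstance(ret, Dependent):
        return ret.fn(*input_domains)
    return ret  # a ground (non-dependent) annotation is returned unchanged


def matmul(
    x: ArrayDomain,
    y: ArrayDomain,
) -> Dependent[lambda x, y: ArrayDomain(x.dtype, (x.shape[0], y.shape[1]))]:
    return x.matmul(y)


print(output_domain(matmul, ArrayDomain("real", (2, 3)), ArrayDomain("real", (3, 4))))
# ArrayDomain(dtype='real', shape=(2, 4))
```

The output domain is computed from the input domains alone, without running the body of `matmul`.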
One open question for this approach is whether to allow more specific dependent type annotations for `x` and `y`, analogous to the shape variables `a`, `b`, `c` above, and how useful that would be for shape checking.
@eb8680 interesting... I guess we could even enforce constraints, something like
```python
def assume(constraint, value):
    if not constraint:
        raise TypeError
    return value

@function
def matmul(
    x: Dependent[lambda x: assume(isinstance(x, Array) and len(x.shape) == 2, x)],
    y: Dependent[lambda y: assume(isinstance(y, Array) and len(y.shape) == 2, y)],
) -> Dependent[lambda x, y: assume(x.dtype == y.dtype and x.shape[1] == y.shape[0],
                                   Array[x.dtype, (x.shape[0], y.shape[1])])]:
    return x.matmul(y)
```
or with a `Where[...]` type.
I like the idea of having `assume` or `Where`, where `Where` takes a base type and a dependent predicate. We could even allow their use simultaneously.
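A minimal sketch of `Where` as a refinement type, assuming it stores a base type and a predicate and checks a value against both, mirroring what `assume` does inline; `Where`, `ArrayDomain`, and `check` are hypothetical names for illustration:

```python
from typing import NamedTuple, Tuple


class ArrayDomain(NamedTuple):
    """Hypothetical stand-in for a funsor domain: a dtype plus a shape."""
    dtype: str
    shape: Tuple[int, ...]


class Where:
    """Hypothetical refinement annotation: a base type plus a predicate on its values."""

    def __init__(self, base, predicate):
        self.base = base
        self.predicate = predicate

    def __class_getitem__(cls, params):
        # Where[base, predicate] passes the tuple (base, predicate).
        base, predicate = params
        return cls(base, predicate)

    def check(self, value):
        """Return value if it has the base type and satisfies the predicate, else raise TypeError."""
        if not isinstance(value, self.base):
            raise TypeError(f"expected {self.base.__name__}, got {type(value).__name__}")
        if not self.predicate(value):
            raise TypeError(f"{value!r} violates the constraint")
        return value


# Equivalent to assume(isinstance(x, Array) and len(x.shape) == 2, x)
Matrix = Where[ArrayDomain, lambda x: len(x.shape) == 2]
print(Matrix.check(ArrayDomain("real", (2, 3))))  # ArrayDomain(dtype='real', shape=(2, 3))
```

Keeping the predicate separate from the base type is what would let `Where` and `Dependent` be combined, as in the desugared signature discussed further down.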
A downside of this notation is that it quickly becomes verbose when specifying complicated shape constraints. Maybe it would still be useful to have a notation like the one in the original post, specifically for `Array`, that desugars to `Dependent`/`Where`. This also has the virtue of nudging users away from writing shape arithmetic constraints, which are harder to get right or to compose.
For example, we could write a batched `matmul` as:
```python
@function
def matmul(
    x: Array['real', (..., "a", "b")],
    y: Array['real', (..., "b", "c")],
) -> Array['real', (..., "a", "c")]:
    return x.matmul(y)
```
which might desugar to
```python
@function
def matmul(
    x: Where[Dependent[lambda x: Array[x.dtype, x.shape]], lambda x: len(x.shape) >= 2],
    y: Where[Dependent[lambda y: Array[y.dtype, y.shape]], lambda y: len(y.shape) >= 2],
) -> Where[
    Dependent[lambda x, y: Array['real', broadcast_shape(x.shape[:-2], y.shape[:-2]) + (x.shape[-2], y.shape[-1])]],
    # generated from shape variables
    lambda x, y: (is_broadcastable(x.shape[:-2], y.shape[:-2])
                  and x.shape[-1] == y.shape[-2]
                  and x.dtype == y.dtype),
]:
    return x.matmul(y)
```
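The desugared signature relies on `broadcast_shape` and `is_broadcastable`, which are not defined in the thread. A minimal sketch, assuming they follow NumPy/PyTorch broadcasting rules (align trailing dimensions; in each position all sizes must be equal or 1). The two names come from the snippet, but these implementations, and the hypothetical `batched_matmul_shape`, are illustrative only:

```python
from itertools import zip_longest
from typing import Tuple


def _aligned(*shapes: Tuple[int, ...]):
    # Iterate dimensions from the right; missing leading dims count as size 1.
    return zip_longest(*(reversed(s) for s in shapes), fillvalue=1)


def is_broadcastable(*shapes: Tuple[int, ...]) -> bool:
    """True if the shapes are compatible under NumPy-style broadcasting."""
    return all(len(set(sizes) - {1}) <= 1 for sizes in _aligned(*shapes))


def broadcast_shape(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape that results from broadcasting the given shapes together."""
    if not is_broadcastable(*shapes):
        raise ValueError(f"shapes {shapes} are not broadcastable")
    # In each position, take the non-1 size if there is one (this also handles size 0).
    result = [next((n for n in sizes if n != 1), 1) for sizes in _aligned(*shapes)]
    return tuple(reversed(result))


def batched_matmul_shape(x_shape, y_shape):
    """Output shape for x: (..., a, b) @ y: (..., b, c), mirroring the desugared annotation."""
    if len(x_shape) < 2 or len(y_shape) < 2 or x_shape[-1] != y_shape[-2]:
        raise TypeError(f"cannot matmul shapes {x_shape} and {y_shape}")
    return broadcast_shape(x_shape[:-2], y_shape[:-2]) + (x_shape[-2], y_shape[-1])


print(batched_matmul_shape((5, 1, 2, 3), (4, 3, 7)))  # (5, 4, 2, 7)
```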
One implication of such an expressive design is that only ground types could be supported in the inputs and outputs of actual funsors, e.g. we couldn't allow the construction of `funsor.Variable(name, Dependent[...])`.
Another set of operations that pose a challenge for this design are those whose output shapes depend on an argument value, even when that value is guaranteed to be an integer or boolean literal. In #442, the function in `Fresh` takes domains as arguments rather than values. The simplest such operation is `.sum(dim=dim)`:
```python
@function
def sum_one_dim(
    x: Array,
    dim: int
) -> Dependent[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim+1:]]]:
    return x.sum(dim)
```
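To make the value dependence concrete: the output shape is a function of the integer value of `dim`, and knowing only that `dim` is an `int` is not enough to compute it. The hypothetical helper below works on plain shape tuples and also normalizes negative `dim` first; as written, the lambda above gives a wrong result for `dim=-1`, because `x.shape[dim+1:]` becomes `x.shape[0:]`, the entire shape.

```python
from typing import Tuple


def sum_one_dim_shape(shape: Tuple[int, ...], dim: int) -> Tuple[int, ...]:
    """Shape left after summing out `dim`; needs the value of dim, not just its type."""
    if not -len(shape) <= dim < len(shape):
        raise IndexError(f"dim {dim} out of range for shape {shape}")
    dim %= len(shape)  # normalize negative dims, e.g. -1 -> len(shape) - 1
    return shape[:dim] + shape[dim + 1:]


print(sum_one_dim_shape((2, 3, 4), 1))   # (2, 4)
print(sum_one_dim_shape((2, 3, 4), -1))  # (2, 3)
```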
For this to work as written, `lambda x, dim: ...` would have to take the value of `dim`, not its type `int`. This forces us to choose a consistent behavior for the other argument `x`:
1. Should this `lambda` take the value of `x` by default, rather than its `.output`? This does not seem ideal if `x` could be a Funsor.
2. Should we require a special notation indicating that the return type of `sum_one_dim` depends on the value of `dim` (e.g. `dim: Value[int]` instead of `dim: int`)?
3. Should we have a different annotation `ValueDependent` for types that depend on values, e.g. the return type of `sum_one_dim`?
4. Alternatively, should we simply disallow value-dependent types and say that `sum_one_dim` should either use a less specific signature or be written as a Funsor term?
Example of option 2:
```python
@function
def sum_one_dim(
    x: Array,
    dim: Value[int]
) -> Dependent[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim+1:]]]:
    return x.sum(dim)
```
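A minimal sketch of how option 2 could be dispatched, assuming `Value[T]` is a marker that tells the type machinery to pass the argument's concrete value (which must be an instance of `T`) into the `Dependent` lambda, while unmarked parameters receive domains. `Value`, `Dependent`, `ArrayDomain`, and `output_domain` are all hypothetical stand-ins here, and the lambda keeps the thread's non-negative-`dim` assumption:

```python
import inspect
from typing import NamedTuple, Tuple


class ArrayDomain(NamedTuple):
    """Hypothetical stand-in for a funsor domain."""
    dtype: str
    shape: Tuple[int, ...]


class Value:
    """Hypothetical marker: Value[T] means the annotation lambda sees a concrete T value."""

    def __init__(self, tp):
        self.tp = tp

    def __class_getitem__(cls, tp):
        return cls(tp)


class Dependent:
    """Hypothetical annotation wrapping a lambda that computes the output domain."""

    def __init__(self, fn):
        self.fn = fn

    def __class_getitem__(cls, fn):
        return cls(fn)


def output_domain(fn, **args):
    """Evaluate fn's Dependent return annotation.

    Parameters annotated Value[T] must be given as concrete values of type T;
    all other parameters are given as domains.
    """
    sig = inspect.signature(fn)
    for name, param in sig.parameters.items():
        ann = param.annotation
        if isinstance(ann, Value) and not isinstance(args[name], ann.tp):
            # Only the type (or a symbolic term) is known, so the output domain is undetermined.
            raise TypeError(f"{name!r} needs a concrete {ann.tp.__name__} value, got {args[name]!r}")
    return sig.return_annotation.fn(**args)


def sum_one_dim(
    x: ArrayDomain,
    dim: Value[int],  # assumes non-negative dim, as in the lambda below
) -> Dependent[lambda x, dim: ArrayDomain(x.dtype, x.shape[:dim] + x.shape[dim + 1:])]:
    return x.sum(dim)


print(output_domain(sum_one_dim, x=ArrayDomain("real", (2, 3, 4)), dim=1))
# ArrayDomain(dtype='real', shape=(2, 4))
```

Passing `dim=int` (only the type) raises `TypeError`. This is the distinction the `Value[int]` annotation makes explicit and a plain `dim: int` cannot express.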