[DISCUSS] Type/shape invariants
This thread outlines the type and shape invariants in Relax, and hopes to generate discussion about the `checked_type_` of `call_packed`, introducing `ObjectType`, and adding `ret_shape` to `relax::Function`.
General invariants
- All Expr should have a non-null `checked_type_` after type inference.
- If the `checked_type_` of an Expr is nullptr, it always means that its type has not been deduced yet.
- If the `shape_` of an Expr with DynTensorType is nullptr, it always means that its shape has not been deduced yet. `Expr->shape_` is optional because not all Expr have a shape: for example, a value with `ShapeType` does not have a shape, so we only do shape deduction for Expr with `DynTensorType`.
- After deduction, the `shape_` of an Expr with DynTensorType is always non-null, since in general it can be assigned `RuntimeDepShape`.
- During type deduction, every Expr's op should have a non-null `checked_type_`. For example, when deducing the type of `TupleGetItem`, check that `TupleGetItem->tuple`'s `checked_type_` is `TupleType`. (A minimal sketch of these invariants as executable checks follows below.)
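To make the rules concrete, here is a minimal Python sketch of the invariants as runtime checks. The `Expr`, `DynTensorType`, and `ShapeType` classes below are simplified stand-ins for illustration, not the actual TVM API:

```python
# Simplified stand-ins for the C++ classes, for illustration only.
from dataclasses import dataclass
from typing import Optional

class Type: ...
class DynTensorType(Type): ...
class ShapeType(Type): ...

@dataclass
class Expr:
    checked_type: Optional[Type] = None  # None means "type not deduced yet"
    shape: Optional[object] = None       # None means "shape not deduced yet"

def check_after_deduction(expr: Expr) -> None:
    # Invariant: every Expr has a non-null checked_type_ after type inference.
    assert expr.checked_type is not None
    # Invariant: an Expr of DynTensorType always has a shape_ after deduction;
    # in the worst case it can be assigned RuntimeDepShape.
    if isinstance(expr.checked_type, DynTensorType):
        assert expr.shape is not None
    # Other types (e.g. ShapeType) legitimately carry no shape_.
```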
call_packed
In most cases we know the return type of a packed function. For example, if a packed function performs a tensor operation, the return value should be of `DynTensorType`. We can annotate the return type via `type_args`, and we can follow the call with a `match_shape` to bind the return shape to symbolic variables.
In TVMScript:
```python
x = relax.call_packed("test.vm.mul", x, y, type_args=(Tensor(rank=2, dtype="float32"),))
```
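Continuing the example, the result's shape could then be matched against symbolic variables `m` and `n`. The `match_shape` surface syntax below is illustrative (an assumption), not settled TVMScript:

```python
# Assumed syntax: bind the packed call's dynamic return shape to m and n.
x = relax.match_shape(x, (m, n))
```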
If `type_args` is not annotated, the return type defaults to `ObjectType`, which serves as a general type. Its runtime representation can be any `tvm::runtime::Object`, such as `String`, `Array`, `Integer`, etc.
Type/shape deduction for Expr
If-then-else
If the two branches give different types, we should find their lowest common ancestor (LCA); a sketch follows the examples below.
- `LCA(DynTensorType[ndim=2, "f32"], DynTensorType[ndim=3, "f32"]) = DynTensorType[ndim=-1, "f32"]`
- `LCA(DynTensorType[ndim=2, "f32"], ShapeType) = ObjectType`
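A minimal Python sketch of one way to compute this LCA, using frozen dataclasses as toy stand-ins for the relax type classes (not the actual implementation):

```python
from dataclasses import dataclass

# Toy stand-ins for the relax type hierarchy, for illustration only.
@dataclass(frozen=True)
class ObjectType: pass

@dataclass(frozen=True)
class ShapeType: pass

@dataclass(frozen=True)
class DynTensorType:
    ndim: int    # -1 means unknown rank
    dtype: str   # "" means unknown dtype

def lca(a, b):
    if a == b:
        return a
    # Two tensor types: keep what agrees, widen what differs to "unknown".
    if isinstance(a, DynTensorType) and isinstance(b, DynTensorType):
        return DynTensorType(
            ndim=a.ndim if a.ndim == b.ndim else -1,
            dtype=a.dtype if a.dtype == b.dtype else "",
        )
    # Unrelated types meet at ObjectType, the top of the hierarchy.
    return ObjectType()

assert lca(DynTensorType(2, "f32"), DynTensorType(3, "f32")) == DynTensorType(-1, "f32")
assert lca(DynTensorType(2, "f32"), ShapeType()) == ObjectType()
```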
Tuple
- `Tuple->fields`' `checked_type_` must not be null.
- A Tuple's `shape_` can be null, for example for `Tuple[Shape, DynTensor, Object]`. When all the fields of a Tuple are of DynTensorType, `shape_` must not be null (see the sketch below).
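A sketch of this rule, again using toy stand-in classes rather than the actual relax types:

```python
# Toy stand-ins, for illustration only.
class DynTensorType: pass
class ShapeType: pass
class ObjectType: pass

def deduce_tuple_shape(field_types, field_shapes):
    # A Tuple's shape_ is defined only when every field is a DynTensorType;
    # e.g. Tuple[Shape, DynTensor, Object] keeps shape_ as None.
    if all(isinstance(t, DynTensorType) for t in field_types):
        return tuple(field_shapes)
    return None
```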
Necessity of explicit return type and shape annotation in Function
We cannot always deduce the type/shape of a Function. Consider the mutual function calls below: we cannot deduce the `auto` return type, so we need manual annotation of `ret_type` and possibly `ret_shape`.
```cpp
auto f1(int y) {
    return f2(y + 1);  // needs the deduced return type of f2
}

auto f2(int x) {
    return f1(x + 1);  // needs the deduced return type of f1: a cycle
}
```
The same problem appears in the recursive case:
```python
def f(x: Tensor[(3, 4), "float32"]):
    if x < 1:
        return x
    else:
        return f(x-1) + 2
```
We cannot deduce the type/shape of `f(x-1)`, because we do not know the return type of `f`, so there is a cyclic dependency:
- In order to deduce the return type/shape of `f`, we need to walk through its body and do deduction;
- In order to do deduction for `f(x-1)`, we need to know the return type/shape of `f`.
To break this cycle, we instead write:
```python
def f(x: Tensor[(3, 4), "float32"]) -> Tensor[(3, 4), "float32"]:
    if x < 1:
        return x
    else:
        return f(x-1) + 2
```
With the annotation, for `f(x-1)` we only need to query the signature (`ret_type` and `ret_shape`) of `f`, without depending on its body to know the return type.
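In other words, call-site deduction becomes a pure signature lookup. A minimal sketch with hypothetical helper names (the actual relax deduction pass is separate):

```python
# Hypothetical sketch: deduce a call's type/shape from the callee's annotated
# signature alone, never walking the callee's body (breaking the cycle above).
def deduce_call(call, signatures):
    sig = signatures[call.op]           # annotated ret_type / ret_shape of f
    return sig.ret_type, sig.ret_shape  # no dependency on f's body
```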