
[DISCUSS] Type/shape invariants

YuchenJin opened this issue 2 years ago.

This thread lists the type and shape invariants in Relax and aims to start a discussion on three topics: the checked_type_ of call_packed, introducing ObjectType, and adding ret_shape to relax::Function.

General invariants

  • After type inference, every Expr should have a non-null checked_type_.
  • If the checked_type_ of an Expr is nullptr, its type has not been deduced yet.
  • If the shape_ of an Expr with DynTensorType is nullptr, its shape has not been deduced yet. Expr->shape_ is optional because not every Expr has a shape; for example, a value of ShapeType has no shape. Shape deduction is therefore only performed for Expr with DynTensorType.
  • After deduction, the shape_ of an Expr with DynTensorType is never nullptr, because it can always fall back to RuntimeDepShape.
  • During type deduction, every operand of an Expr must already have a non-null checked_type_. For example, when deducing the type of TupleGetItem, check that the checked_type_ of TupleGetItem->tuple is a TupleType.
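The post-deduction invariants above can be sketched as a small checker. This is a toy model, not TVM code: Expr, DynTensorType, ShapeType, RUNTIME_DEP_SHAPE, and check_invariants are hypothetical stand-ins, and Python None plays the role of nullptr.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(frozen=True)
class DynTensorType:
    ndim: int
    dtype: str


@dataclass(frozen=True)
class ShapeType:
    pass


RUNTIME_DEP_SHAPE = "RuntimeDepShape"  # hypothetical "shape known only at runtime" marker


@dataclass
class Expr:
    checked_type_: Optional[object] = None  # None stands in for nullptr
    shape_: Optional[object] = None
    operands: List["Expr"] = field(default_factory=list)


def check_invariants(expr: Expr) -> List[str]:
    """Collect violations of the post-deduction invariants in an expression tree."""
    errors: List[str] = []
    if expr.checked_type_ is None:
        errors.append("checked_type_ is null after type inference")
    elif isinstance(expr.checked_type_, DynTensorType) and expr.shape_ is None:
        # Shapes are only required for DynTensorType; RuntimeDepShape is the fallback.
        errors.append("DynTensorType Expr has null shape_ after deduction")
    for operand in expr.operands:
        errors.extend(check_invariants(operand))
    return errors
```

A ShapeType value passes without a shape_, while a DynTensorType value must carry at least the RuntimeDepShape marker.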

call_packed

In most cases we know the return type of a packed function. For example, if a packed function performs a tensor operation, its return value should be of DynTensorType. We can annotate the return type with type_args, and then follow the call with match_shape to bind the return shape to symbolic variables.

In TVMScript:

   x = relax.call_packed("test.vm.mul", x, y, type_args=(Tensor(rank=2, dtype="float32")))
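The call above fixes only the rank and dtype; the concrete extents come from a subsequent match_shape. As a rough mental model, not the Relax implementation, matching can be viewed as binding each symbolic variable to the runtime extent on first use and checking it on every later use. match_shape below is a hypothetical helper, with symbolic variables modeled as strings and fixed extents as ints.

```python
from typing import Dict, Sequence, Union

Dim = Union[int, str]  # an int is a fixed extent; a str names a symbolic variable


def match_shape(shape: Sequence[int], pattern: Sequence[Dim], env: Dict[str, int]) -> Dict[str, int]:
    """Match a runtime shape against a symbolic pattern.

    Returns a copy of env extended with newly bound variables; raises
    ValueError if the rank, a fixed extent, or an existing binding disagrees.
    """
    if len(shape) != len(pattern):
        raise ValueError(f"rank mismatch: {len(shape)} vs {len(pattern)}")
    bound = dict(env)
    for actual, expected in zip(shape, pattern):
        if isinstance(expected, int):
            if actual != expected:
                raise ValueError(f"expected extent {expected}, got {actual}")
        elif expected in bound:
            if bound[expected] != actual:
                raise ValueError(f"{expected} is bound to {bound[expected]}, got {actual}")
        else:
            bound[expected] = actual  # first occurrence: bind the symbolic variable
    return bound
```

For example, matching a (3, 4) result against (n, m) binds n = 3 and m = 4, and later code can rely on those bindings.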

If type_args is not annotated, call_packed returns ObjectType, which serves as a general type. Values of this type are represented at runtime as tvm::runtime::Object instances such as String, Array, Integer, etc.
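The rule for the checked_type_ of a call_packed can be sketched as follows. DynTensorType and ObjectType here are hypothetical Python stand-ins, and call_packed_ret_type is an illustrative helper, not a TVM function.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class DynTensorType:
    ndim: int
    dtype: str


@dataclass(frozen=True)
class ObjectType:
    """General type covering any runtime Object (String, Array, ...)."""


RelaxType = Union[DynTensorType, ObjectType]


def call_packed_ret_type(type_args: Optional[RelaxType]) -> RelaxType:
    # Trust the annotation when present; otherwise fall back to the general ObjectType.
    return type_args if type_args is not None else ObjectType()
```

Falling back to the most general type keeps deduction total: every call_packed gets a non-null checked_type_, even when nothing is known about the callee.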

Type/shape deduction for Expr

If-then-else

If the two branches give different types, the type of the If expression should be their lowest common ancestor (LCA) in the type hierarchy:

  • LCA(DynTensorType[ndim=2, "f32"], DynTensorType[ndim=3, "f32"]) = DynTensorType[ndim=-1, "f32"]
  • LCA(DynTensorType[ndim=2, "f32"], ShapeType) = ObjectType
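A minimal sketch of the LCA rule for the types named in this thread. The classes and lca are hypothetical Python stand-ins, and ndim=-1 means unknown rank as above. The handling of differing dtypes (dtype=None meaning unknown) is an assumption; the thread does not specify it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ObjectType:
    pass


@dataclass(frozen=True)
class ShapeType:
    pass


@dataclass(frozen=True)
class DynTensorType:
    ndim: int             # -1 means unknown rank
    dtype: Optional[str]  # None means unknown dtype (assumption, not from the thread)


def lca(a, b):
    """Return the least general type that both a and b are instances of."""
    if a == b:
        return a
    if isinstance(a, DynTensorType) and isinstance(b, DynTensorType):
        # Keep whatever the two tensor types agree on; generalize the rest.
        ndim = a.ndim if a.ndim == b.ndim else -1
        dtype = a.dtype if a.dtype == b.dtype else None
        return DynTensorType(ndim, dtype)
    # Unrelated kinds (e.g. a tensor and a shape) only share ObjectType.
    return ObjectType()
```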

Tuple

  • Every field of a Tuple must have a non-null checked_type_.
  • A Tuple's shape_ can be null, for example for Tuple[Shape, DynTensor, Object]. When all the fields of a Tuple are of DynTensorType, its shape_ must not be null.
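Tuple deduction under these rules can be sketched as below. The classes and deduce_tuple are hypothetical, and representing a tuple's shape_ as the tuple of its fields' shapes is an assumption for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class DynTensorType:
    ndim: int
    dtype: str


@dataclass(frozen=True)
class ShapeType:
    pass


@dataclass(frozen=True)
class TupleType:
    fields: Tuple[object, ...]


@dataclass
class Expr:
    checked_type_: Optional[object] = None  # None stands in for nullptr
    shape_: Optional[object] = None


def deduce_tuple(fields: List[Expr]) -> Expr:
    # Every field must already be typed before the Tuple itself can be.
    for i, f in enumerate(fields):
        if f.checked_type_ is None:
            raise ValueError(f"field {i} has no checked_type_; deduce it first")
    ty = TupleType(tuple(f.checked_type_ for f in fields))
    if all(isinstance(f.checked_type_, DynTensorType) for f in fields):
        shape = tuple(f.shape_ for f in fields)  # all-tensor tuple: shape_ is non-null
    else:
        shape = None  # e.g. Tuple[Shape, DynTensor, Object]: no shape_
    return Expr(ty, shape)
```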

Necessity of explicit return type and shape annotation in Function

We cannot always deduce the type/shape of a Function. For mutually recursive functions such as the ones below, the auto return types cannot be deduced, so ret_type (and possibly ret_shape) has to be annotated manually.

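auto f2(int x);  // forward declaration: return type not yet deduced
auto f1(int y) {
   return f2(y+1);  // error: f2 is used before its return type is deduced
}
auto f2(int x) {
   return f1(x+1);
}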

The same problem arises with a directly recursive function:

def f(x: Tensor[(3, 4), "float32"]):
   if x < 1:
      return x
   else:
      return f(x-1) + 2

We cannot deduce the type/shape of f(x-1), because we do not know the return type of f, so there is a cyclic dependency:

  • In order to deduce return type/shape of f, we need to walk through its body and do deduction;
  • In order to do deduction of f(x-1) we want to know the return type/shape of f.

To break this chain, we annotate the return type and shape explicitly instead:

def f(x: Tensor[(3, 4), "float32"]) -> Tensor[(3, 4), "float32"]:
   if x < 1:
      return x
   else:
      return f(x-1) + 2

Then, to deduce f(x-1), we only need to query the signature of f (ret_type and ret_shape), without depending on its body to know the return type.
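The effect of the annotation can be illustrated with a toy deducer. Everything here is hypothetical: Function, TensorType, CyclicDeductionError, and deduce_ret_type are not Relax APIs. A function body is reduced to what each return statement returns ("param" or the name of a called function), and f(x-1) + 2 is assumed to have the same type as f(x-1).

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, Optional, Tuple


@dataclass(frozen=True)
class TensorType:
    shape: Tuple[int, ...]
    dtype: str


@dataclass(frozen=True)
class Function:
    name: str
    param_type: TensorType
    # Toy body: what each return statement returns, either "param" or the
    # name of the function whose call result is returned.
    returns: Tuple[str, ...]
    ret_type: Optional[TensorType] = None  # explicit "-> ..." annotation


class CyclicDeductionError(Exception):
    pass


def deduce_ret_type(name: str, funcs: Dict[str, Function],
                    in_progress: FrozenSet[str] = frozenset()) -> TensorType:
    fn = funcs[name]
    if fn.ret_type is not None:
        # Annotated signature: answer the query without walking the body.
        return fn.ret_type
    if name in in_progress:
        # We are already inside the body of this function: cyclic dependency.
        raise CyclicDeductionError(f"return type of {name} depends on itself")
    inner = in_progress | {name}
    types = {fn.param_type if r == "param" else deduce_ret_type(r, funcs, inner)
             for r in fn.returns}
    if len(types) != 1:
        raise TypeError(f"branches of {name} return different types: {types}")
    return types.pop()
```

With ret_type set, the query for a call stops at the signature; without it, the deducer has to re-enter the function it is already deducing and reports the cycle. In the mutually recursive case, annotating just one of the two functions is enough to break the cycle in this model.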

YuchenJin, Apr 05 '22 23:04