[DISCUSS] TensorArray/TensorList support
TensorArray Support
TensorArray is commonly used with control flow to save results in a loop. When looking into the Relay IR converted from TensorFlow or PyTorch models, such as TensorFlow SSD and PyTorch LSTM, we found many TensorArray/TensorList-related operations there. We plan to:
- Introduce a generic object type to express TensorArray.
- Implement TensorArray/TensorList operations as external packed functions. The TensorArray/TensorList operations to be supported, listed in the table below, come from public LSTM/object-detection/segmentation models implemented in TensorFlow/PyTorch (a sketch of their intended semantics follows the table).
op | description
---|---
`TensorArraySize` | returns the current size of the TensorArray
`TensorArrayWrite` | writes an element into the TensorArray
`TensorArrayRead` | reads an element from the TensorArray
`TensorArraySplit` | splits the input data into elements of the TensorArray
`TensorArrayScatter` | scatters the input data into specific elements
`TensorArrayConcat` | concatenates all elements
`TensorListFromTensor` | creates a TensorArray from a tensor
`TensorListGetItem` | gets the element at a given index
`TensorListReserve` | reserves a TensorArray of a given size with empty elements
`TensorListSetItem` | sets the element at a given index
`TensorListStack` | stacks all elements into one tensor
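
To make the intended semantics concrete, here is a minimal pure-Python/NumPy reference sketch of a few of these ops. This is illustrative only; the proposal above is to implement them as external packed functions in the Relax VM, and all function names here are hypothetical:

```python
# Hypothetical reference semantics for a few TensorArray/TensorList ops,
# using a plain Python list of NumPy arrays as the backing store.
import numpy as np

def tensor_list_reserve(size):
    # TensorListReserve: a list of `size` empty (None) slots.
    return [None] * size

def tensor_array_write(ta, index, value):
    # TensorArrayWrite: store `value` at `index`, growing if needed.
    if index >= len(ta):
        ta = ta + [None] * (index + 1 - len(ta))
    ta[index] = np.asarray(value)
    return ta

def tensor_array_read(ta, index):
    # TensorArrayRead: O(1) lookup on a list-backed representation.
    return ta[index]

def tensor_array_concat(ta):
    # TensorArrayConcat: concatenate all elements along axis 0.
    return np.concatenate(ta, axis=0)

def tensor_list_stack(ta):
    # TensorListStack: stack all elements (requires identical shapes).
    return np.stack(ta, axis=0)

ta = tensor_list_reserve(2)
ta = tensor_array_write(ta, 0, [[1, 2, 3]])
ta = tensor_array_write(ta, 1, [[4, 5, 6]])
print(tensor_array_concat(ta))  # shape (2, 3)
```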
Thanks for the great proposal! Two questions from me:
- Should we generalize the Array/List to support other object types, like shapes?
- If I understand it correctly, Relay achieves loops via recursion since it is a functional language, but now we have side effects in Relax programs. Do we still want to take this approach?
`TensorArrayConcat` is very important to efficiently support the common idiom of:

```python
processed = []
for tensor in some_dynamic_list_of_tensors:
    processed.append(do_postprocess(tensor))
return stack(processed)  # or concat
```
PyTorch MaskRCNN has such a post-processing loop (https://github.com/pytorch/vision/blob/d367a01a18a3ae6bee13d8be3b63fd6a581ea46f/torchvision/models/detection/roi_heads.py#L463-L469), which is currently very slow after conversion to TVM for various reasons. We can convert that code using `tensor_array_concat`, but it is too slow.
We need to be able to express a parallel batched concat like PyTorch does: https://github.com/pytorch/pytorch/blob/36ddb00240604c328358d4c35cbb042674a8ecf8/aten/src/ATen/native/cuda/Shape.cu#L318.
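
For reference, the idea behind that kernel is roughly: precompute the output offset of every input first, so that all copies become independent and can run in parallel (PyTorch's CUDA kernel does them in a single launch). A hedged NumPy sketch of the offset computation, not the CUDA kernel itself:

```python
# Sketch of batched concat along axis 0: precompute each input's output
# offset, then every copy is independent of the others.
import numpy as np

def batched_concat(tensors):
    sizes = [t.shape[0] for t in tensors]
    offsets = np.cumsum([0] + sizes)          # output row offset per input
    out = np.empty((offsets[-1],) + tensors[0].shape[1:], tensors[0].dtype)
    for t, start in zip(tensors, offsets):    # copies could run in parallel
        out[start:start + t.shape[0]] = t
    return out
```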
Also, it is important for `TensorArrayRead` to be O(1). Currently it is O(N) 🤦‍♂️
Also adding some of the discussion about TensorArray from yesterday's Relax dev meeting and the Discord channel:
@mbs-octoml @jwfromm: In common use cases, for example in detection models, the tensors in a TensorArray have the same shape, so in theory we can use a dynamic-shape tensor to represent it. For example, we can use `Tensor[n, 2, 3]` to represent an array of tensors that all have shape `[2, 3]`.
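
As an illustration of that representation (a hedged NumPy sketch, not a proposed API): if all elements share shape `[2, 3]`, the whole array packs into one tensor and read/write become plain indexing:

```python
# Sketch: a TensorArray whose elements all have shape (2, 3), packed as a
# single dynamic-shape tensor of shape (n, 2, 3).
import numpy as np

n = 4
packed = np.zeros((n, 2, 3), dtype="float32")   # Tensor[n, 2, 3]
packed[1] = np.ones((2, 3))                     # TensorArrayWrite(ta, 1, v)
elem = packed[1]                                # TensorArrayRead(ta, 1)
```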
@junrushao1994: When representing a fixed-length list, we can safely convert it to a Tuple; when representing a variable-length list whose elements may have different shapes/dtypes/etc., TensorArray seems to be a good choice.
@yongwww: A general TensorArray can have elements with different shapes (e.g., `tf.TensorArray`, `List` in TorchScript), and having general TensorArray support also allows us to support RaggedTensor.
> When representing a fixed-length list, we can safely convert it to a Tuple

What I talked about on Discord was that this is not always the case: for static-list concat, `Tuple` is not ideal. It is "safe", but not efficient.
> Also, it is important for `TensorArrayRead` to be O(1). Currently it is O(N) 🤦‍♂️
If our runtime representation of TensorArray is `Array<NDArray>`, which the Relax VM supports, I think `TensorArrayRead` should be O(1). :)
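
For example (a sketch assuming a TVM build where `tvm.runtime.convert` produces the runtime `Array` container; exact APIs may differ across versions):

```python
# Sketch: a TensorArray represented as Array<NDArray> gives O(1) reads.
import numpy as np
import tvm

elems = [tvm.nd.array(np.full((2, 3), i, "float32")) for i in range(4)]
ta = tvm.runtime.convert(elems)   # runtime Array holding NDArrays
x = ta[2]                         # TensorArrayRead: constant-time indexing
print(x.numpy())
```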
I looked into object detection models (e.g., SSD, MaskRCNN, CenterNet), and I can confirm that the elements of the TensorArray/TensorList in these models have the same shape. See the TensorFlow Mask-RCNN-Inception-ResNet-v2 graphdef or the PyTorch MaskRCNN-ResNet50 JIT model for detailed info about the model structures. Even for a TensorArray with identical element shapes, expressing it through a dynamic-shape tensor is not my preferred option, because a bunch of TensorArray/TensorList-related operators, like those listed in the table above, operate on the TensorArray/TensorList. Those operators are not supported on dynamic tensors, and it is difficult to enable some of them. Take `TensorListGetItem` for example: it gets a specific tensor from the TensorArray by index, and it may remove the returned tensor from the TensorArray after the call. That is difficult to support with a dynamic tensor, especially considering we would like to get a specific element in O(1).
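
To illustrate the difficulty (a NumPy sketch, under the assumption that removal shifts the later elements forward): popping element `i` from a list-backed TensorArray only rearranges list pointers, while a packed dynamic-shape tensor must copy every remaining row:

```python
# Sketch: removing element i from the two candidate representations.
import numpy as np

# List-backed TensorArray: O(1) read; pop moves only list pointers,
# no tensor data is copied.
ta = [np.zeros((2, 3)) for _ in range(4)]
elem = ta.pop(1)

# Packed dynamic-shape tensor: removing row 1 allocates a new buffer and
# copies all remaining rows, i.e. O(N) in the total data size.
packed = np.zeros((4, 2, 3))
elem = packed[1]
packed = np.delete(packed, 1, axis=0)
```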
Actually, a general TensorArray/TensorList is able to support elements with different shapes; see the following test for an example. Enabling support for a general TensorArray/TensorList doesn't hurt, even though the variable-shape cases don't show up in the target models of this control flow project. We might still encounter them at some point.
```python
import tensorflow as tf

# infer_shape=False allows elements with different shapes.
ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
ta = ta.write(0, 1)          # scalar
ta = ta.write(1, (1, 2))     # shape [2]
ta = ta.write(2, (1, 2, 3))  # shape [3]
```
Changed this topic to cover TensorArray/TensorList only. I moved the control-flow-related discussion into an independent thread: https://github.com/tlc-pack/relax/issues/93. Feel free to leave comments regarding control flow there.