torch-mlir
Adding support for list of tensors as return type.
Some of the operations, like aten.split and aten.split_with_sizes, return a list of tensors. This is currently not supported in torch-mlir.
def split_with_sizes(self: Tensor, split_sizes: List[int], dim: int = 0) -> List[Tensor]:
    num_splits = len(split_sizes)
    splits = []
    start_idx = 0
    for i in range(num_splits):
        length = split_sizes[i]
        splits.append(self.narrow(dim, start_idx, length))
        start_idx += length
    return splits

def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
    input_sizes = self.shape
    dim_size = input_sizes[dim]
    if split_size == 0:
        assert dim_size == 0
        return [self]
    chunks = (dim_size + split_size - 1) // split_size
    split_sizes = [split_size for i in range(chunks)]
    split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
    return aten.split_with_sizes(self, split_sizes, dim)
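As a quick check of the chunk arithmetic in split, the sketch below reproduces only the size computation on plain integers, with no tensors involved. The helper name compute_split_sizes is hypothetical and is not part of PyTorch or torch-mlir; the guard on an empty result is an addition of this sketch.

```python
from typing import List


def compute_split_sizes(dim_size: int, split_size: int) -> List[int]:
    """Hypothetical helper mirroring the size arithmetic in split."""
    if split_size == 0:
        # Matches the decomposition: only an empty dimension may be split by 0,
        # and the result is a single chunk covering the whole dimension.
        assert dim_size == 0
        return [dim_size]
    # Ceiling division: number of chunks needed to cover dim_size.
    chunks = (dim_size + split_size - 1) // split_size
    split_sizes = [split_size] * chunks
    if chunks > 0:
        # Guard added in this sketch so an empty list is never indexed.
        # The last chunk absorbs the remainder and may be shorter.
        split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
    return split_sizes


print(compute_split_sizes(10, 3))  # [3, 3, 3, 1]
print(compute_split_sizes(8, 4))   # [4, 4]
```

The sizes always sum to dim_size, which is what lets split delegate to split_with_sizes.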
I have started working on this. Please share your thoughts on this issue. CC: @silvasean @cathyzhyi
If the number of splits is static, then we can lower it to a fixed set of SSA values. Otherwise, it requires a runtime list type which we don't have.
Generally, to use the decompositions, I think we need to do further design work similar to the shape library. Anush and I discussed in our sync -- I think he should be reaching out to file an issue and do some design work here.
Hi @silvasean, could you please elaborate on how I can add a runtime list type in torch-mlir? It would be great if you could provide a starting point to look at. Even if we want to lower the split operation in torch-mlir, we would need that support.
Adding a runtime list type is quite difficult and not a direction that I think is the best for the project right now. Do you actually have a model where the number of splits is dynamic?
I think we need to take a step back and look at the problem of decompositions from the ground up. cc @powderluv
Thanks! As discussed over chat, I will try to lower this operation in ListConstructOp for static offsets.
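One way to picture the static case: when split_sizes is known ahead of time, the list result can be unrolled into a fixed number of (start, length) pairs, one per narrow call, so no runtime list is needed. The sketch below shows that unrolling on plain Python integers. The name static_narrow_ranges is hypothetical, and this is only an illustration of the offsets, not the torch-mlir lowering itself.

```python
from typing import List, Tuple


def static_narrow_ranges(split_sizes: List[int]) -> List[Tuple[int, int]]:
    """Hypothetical helper: turn known split sizes into (start, length) pairs."""
    ranges = []
    start_idx = 0
    for length in split_sizes:
        # Each pair corresponds to one self.narrow(dim, start_idx, length) call.
        ranges.append((start_idx, length))
        start_idx += length
    return ranges


# With sizes known up front, the number of results is fixed (three here).
print(static_narrow_ranges([2, 3, 5]))  # [(0, 2), (2, 3), (5, 5)]
```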
@gprateek93 can we close this?