Missing propagation for `unpack -> collapse_shape` to `collapse_shape -> unpack`.

Open hanhanW opened this issue 1 year ago • 4 comments

This blocks fusion of the unpack with its consumers. E.g., in the IR below we should be able to swap the unpack and the collapse_shape, because the collapse_shape is just folding a unit dim away.

  func.func @foo(%arg0: tensor<1x1024x1024x16x16xf32>) -> tensor<16384x16384xf32> {
    %0 = tensor.empty() : tensor<1x16384x16384xf32>
    %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %0 : tensor<1x1024x1024x16x16xf32> -> tensor<1x16384x16384xf32>
    %collapsed = tensor.collapse_shape %unpack [[0, 1], [2]] : tensor<1x16384x16384xf32> into tensor<16384x16384xf32>
    %1 = tensor.empty() : tensor<16384x16384xf32>
    %2 = linalg.softmax dimension(1) ins(%collapsed : tensor<16384x16384xf32>) outs(%1 : tensor<16384x16384xf32>) -> tensor<16384x16384xf32>
    return %2 : tensor<16384x16384xf32>
  }
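
For illustration, a sketch of the IR we might expect after the swap, assuming the pattern only has to drop the leading unit dim from the packed source. The function name `@foo_swapped` and the exact reassociation are assumptions for this example, not output from an existing pass:

  func.func @foo_swapped(%arg0: tensor<1x1024x1024x16x16xf32>) -> tensor<16384x16384xf32> {
    // Fold the unit outer dim on the packed source instead of on the unpacked result.
    %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3], [4]] : tensor<1x1024x1024x16x16xf32> into tensor<1024x1024x16x16xf32>
    %0 = tensor.empty() : tensor<16384x16384xf32>
    // inner_dims_pos shifts from [1, 2] to [0, 1] because the unit dim is gone.
    %unpack = tensor.unpack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<1024x1024x16x16xf32> -> tensor<16384x16384xf32>
    %1 = tensor.empty() : tensor<16384x16384xf32>
    %2 = linalg.softmax dimension(1) ins(%unpack : tensor<16384x16384xf32>) outs(%1 : tensor<16384x16384xf32>) -> tensor<16384x16384xf32>
    return %2 : tensor<16384x16384xf32>
  }

In this form the unpack feeds the softmax directly, so the reshape no longer sits between the unpack and its consumer.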

hanhanW avatar May 22 '24 00:05 hanhanW

It should be done in https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp

hanhanW avatar May 22 '24 01:05 hanhanW

Where is this collapse_shape coming from? There might be a uniform way of handling this in the reshape propagation passes later on.

MaheshRavishankar avatar May 22 '24 01:05 MaheshRavishankar

I don't know. It is there after set encoding. A sequence of linalg ops is raised to a softmax op in the GlobalOptimization stage. Are we able to push reshape ops down through named ops? It does not look easy to me, so I think we can implement an (unpack, collapse_shape) propagation pattern for this case.

hanhanW avatar May 22 '24 02:05 hanhanW

> I don't know. It is there after set encoding. A sequence of linalg ops is raised to a softmax op in the GlobalOptimization stage. Are we able to push reshape ops down through named ops? It does not look easy to me, so I think we can implement an (unpack, collapse_shape) propagation pattern for this case.

The propagation patterns are implemented for Linalg ops, but we can add propagation patterns for other ops as well. I'd like to consolidate all the propagation patterns in one place if possible. We can still add those patterns, but we should be able to use them in the reshape propagation passes.

MaheshRavishankar avatar May 22 '24 02:05 MaheshRavishankar