tvm
tvm copied to clipboard
[Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat
This commit implements an optional optimization pass relax.transform.ReorderPermuteDimsAfterConcat, which reorders expressions of the form R.concat(R.permute_dims(A), R.permute_dims(B)) into R.permute_dims(R.concat(A,B)).
This pass is intended to be used alongside CombineParallelMatmul. After parallel matmuls are combined, this pass allows the R.permute_dims to be lifted out of the R.concat, and allows optimized nn.Linear kernels to find the R.matmul(x, R.permute_dims(weights)) patterns they are looking for.
@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """Initial IRModule

    The `R.permute_dims` followed by `R.matmul` is the relax
    equivalent of `nn.Linear`, and will frequently have optimized
    kernels.
    """
    # Each matmul must consume the *transposed* weight; that
    # `R.matmul(x, R.permute_dims(w))` shape is the `nn.Linear` pattern.
    weight_query_T = R.permute_dims(weight_query)
    query = R.matmul(x, weight_query_T)
    weight_key_T = R.permute_dims(weight_key)
    key = R.matmul(x, weight_key_T)
    weight_value_T = R.permute_dims(weight_value)
    value = R.matmul(x, weight_value_T)
@R.function
def func(
    x: R.Tensor,
    weight_query: R.Tensor,
    weight_key: R.Tensor,
    weight_value: R.Tensor,
):
    """After `CombineParallelMatmul`

    Only one matmul is performed now, which is generally better than
    performing three small matmuls.  The optimized kernels for
    `nn.Linear` no longer apply, though: the `R.concat` sits between
    `R.permute_dims` and `R.matmul`, and it isn't part of the pattern
    those kernels expect.
    """
    # Every weight is still transposed individually ...
    weight_query_T = R.permute_dims(weight_query)
    weight_key_T = R.permute_dims(weight_key)
    weight_value_T = R.permute_dims(weight_value)

    # ... and the transposed weights are then joined into one operand.
    fused_weight_T = R.concat([weight_query_T, weight_key_T, weight_value_T], axis=1)

    # One fused matmul, split back into the three original results.
    fused_qkv = R.matmul(x, fused_weight_T)
    query, key, value = R.split(fused_qkv)
@R.function
def func(
    x: R.Tensor,
    weight_query: R.Tensor,
    weight_key: R.Tensor,
    weight_value: R.Tensor,
):
    """After `ReorderPermuteDimsAfterConcat`

    The matmul count is unchanged (still one), and because its operand
    is once again a direct `R.permute_dims(...)`, the optimized kernels
    for `nn.Linear` can be applied again.
    """
    # Join the untransposed weights first, then transpose once.
    # NOTE(review): for 2-D weights, concat on axis 0 followed by a
    # transpose matches transposing each weight and concatenating on
    # axis 1 -- confirm the weights are 2-D.
    fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)
    fused_weight_T = R.permute_dims(fused_weight)

    # One fused matmul, split back into the three original results.
    fused_qkv = R.matmul(x, fused_weight_T)
    query, key, value = R.split(fused_qkv)