
[Transform] Implement relax.transform.ReorderPermuteDimsAfterConcat

Lunderberg opened this issue 1 year ago

This commit implements an optional optimization pass, relax.transform.ReorderPermuteDimsAfterConcat, which reorders expressions of the form R.concat([R.permute_dims(A), R.permute_dims(B)]) into R.permute_dims(R.concat([A, B])).

This pass is intended to be used alongside CombineParallelMatmul. After parallel matmuls are combined, this pass lifts the shared R.permute_dims out of the R.concat, restoring the R.matmul(x, R.permute_dims(weights)) pattern that optimized nn.Linear kernels are looking for.
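The rewrite is valid because transposing each operand and concatenating along the columns is the same as concatenating along the rows and transposing the result. A minimal pure-Python sketch of this identity (illustration only, not the TVM API):

```python
def transpose(m):
    # swap rows and columns of a nested-list matrix
    return [list(row) for row in zip(*m)]

def concat_rows(a, b):
    # concatenate along axis=0 (stack rows)
    return a + b

def concat_cols(a, b):
    # concatenate along axis=1 (append columns row by row)
    return [ra + rb for ra, rb in zip(a, b)]

A = [[1, 2, 3], [4, 5, 6]]  # shape (2, 3)
B = [[7, 8, 9]]             # shape (1, 3)

# R.concat([R.permute_dims(A), R.permute_dims(B)], axis=1)
lhs = concat_cols(transpose(A), transpose(B))
# R.permute_dims(R.concat([A, B], axis=0))
rhs = transpose(concat_rows(A, B))
assert lhs == rhs
```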

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """Initial IRModule

    The `R.permute_dims` followed by `R.matmul` is the relax
    equivalent of `nn.Linear`, and will frequently have optimized
    kernels.
    """
    weight_query_T = R.permute_dims(weight_query)
    query = R.matmul(x, weight_query_T)
    weight_key_T = R.permute_dims(weight_key)
    key = R.matmul(x, weight_key_T)
    weight_value_T = R.permute_dims(weight_value)
    value = R.matmul(x, weight_value_T)
    return (query, key, value)

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `CombineParallelMatmul`

    There's now only a single matmul to be performed, which is
    generally better than performing three small matmuls.  However,
    the optimized kernels for `nn.Linear` can no longer be applied,
    because the `R.concat` isn't part of the expected pattern.
    """
    weight_query_T = R.permute_dims(weight_query)
    weight_key_T = R.permute_dims(weight_key)
    weight_value_T = R.permute_dims(weight_value)

    fused_weight_T = R.concat([weight_query_T, weight_key_T, weight_value_T], axis=1)
    fused_qkv = R.matmul(x, fused_weight_T)

    query, key, value = R.split(fused_qkv, indices_or_sections=3, axis=-1)
    return (query, key, value)
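The fusion performed by CombineParallelMatmul rests on a second identity: multiplying x by the column-wise concatenation of the transposed weights gives the same result as concatenating the individual products. A pure-Python sketch with toy shapes (illustration only, not the TVM API):

```python
def matmul(a, b):
    # naive dense matmul over nested lists, for illustration
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def concat_cols(*ms):
    # concatenate matrices along axis=1
    return [sum(rows, []) for rows in zip(*ms)]

x = [[1, 2]]             # shape (1, 2)
wq_T = [[1, 0], [0, 1]]  # transposed query weights, shape (2, 2)
wk_T = [[2, 0], [0, 2]]  # transposed key weights, shape (2, 2)

# single fused matmul, as produced by CombineParallelMatmul
fused = matmul(x, concat_cols(wq_T, wk_T))
# the three (here: two) separate matmuls it replaces
separate = concat_cols(matmul(x, wq_T), matmul(x, wk_T))
assert fused == separate
```

Splitting `fused` back along the last axis then recovers the individual query/key/value results, which is what the trailing `R.split` does.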

@R.function
def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, weight_value: R.Tensor):
    """After `ReorderPermuteDimsAfterConcat`

    There's still only a single matmul, and the optimized kernels for
    `nn.Linear` can be applied again.
    """
    fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)

    fused_weight_T = R.permute_dims(fused_weight)
    fused_qkv = R.matmul(x, fused_weight_T)

    query, key, value = R.split(fused_qkv, indices_or_sections=3, axis=-1)
    return (query, key, value)

Lunderberg, Feb 16 '24 19:02