
[WIP] Pattern-based design transform

Open · hanchenye opened this issue 2 years ago · 4 comments

Add feature: Adds a pattern library with a Pattern class that lets users define pattern-rewrite/transform rules in Python. An apply() API is also added to the Schedule class to apply user-defined patterns to an existing schedule.

How to use the new feature: Patterns are defined as Python functions that always take a Pattern object as their first argument:

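import heterocl as hcl

def pattern1(p: hcl.Pattern):
    dtype = hcl.Int(32)

    # Three wildcard operands, each matching any i32 value in the schedule
    v = p.value(dtype)
    z = p.value(dtype)
    a = p.value(dtype)
    # Overloaded operators describe the operation tree to match: z + (a * v)
    res = z + a * v

    # Open the transform session on the matched root operation
    res = p.start_transform(res)
    target_loop = p.get_parent_loop(res)
    p.loop_unroll(target_loop, 2)  # unroll the enclosing loop by a factor of 2
    p.end_transform_or_rewrite()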

The value() API of the Pattern class returns a ValueHandle that matches any MLIR value in the schedule with the specified type. Arithmetic operators between ValueHandles are overloaded to match the corresponding MLIR arithmetic operations between the matched values. The start_transform() and end_transform_or_rewrite() APIs delimit the transform session of the pattern, in which the matched operation (here, res) is transformed according to the specified rules. In this example, we first get the parent loop of res with the get_parent_loop() API, then unroll that loop by a factor of 2 with the loop_unroll() API.
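To make the matching semantics concrete, here is a pure-Python toy sketch (no HeteroCL or MLIR involved) of how overloaded operators on a handle can record a pattern tree and then match it against expressions. ValueHandle, Val, Op and match are hypothetical names for this sketch only, not the patch's implementation, which lowers patterns to PDL instead. Matching here is order-sensitive (no commutativity) and a leaf handle may bind to an operation result as well as to a plain value.

```python
from dataclasses import dataclass
from itertools import count
from typing import Dict, Optional, Union

@dataclass(frozen=True)
class Val:                      # toy IR leaf value (e.g. a block argument)
    name: str
    dtype: str

@dataclass(frozen=True)
class Op:                       # toy IR binary arithmetic operation
    kind: str                   # "add" or "mul"
    lhs: "Expr"
    rhs: "Expr"

    @property
    def dtype(self) -> str:     # toy assumption: result type == operand type
        return self.lhs.dtype

Expr = Union[Val, Op]
_ids = count()

class ValueHandle:
    """Hypothetical stand-in for the object returned by p.value(dtype)."""

    def __init__(self, dtype: Optional[str] = None, kind: str = "any",
                 lhs: Optional["ValueHandle"] = None,
                 rhs: Optional["ValueHandle"] = None):
        self.id, self.dtype, self.kind = next(_ids), dtype, kind
        self.lhs, self.rhs = lhs, rhs

    # Operator overloading records the pattern tree instead of computing anything.
    def __add__(self, other: "ValueHandle") -> "ValueHandle":
        return ValueHandle(kind="add", lhs=self, rhs=other)

    def __mul__(self, other: "ValueHandle") -> "ValueHandle":
        return ValueHandle(kind="mul", lhs=self, rhs=other)

def match(pat: ValueHandle, e: Expr, binds: Dict[int, Expr]) -> bool:
    if pat.kind == "any":       # a leaf handle matches any value of its type
        if e.dtype != pat.dtype:
            return False
        if pat.id in binds:     # the same handle must always bind the same value
            return binds[pat.id] == e
        binds[pat.id] = e
        return True
    return (isinstance(e, Op) and e.kind == pat.kind
            and match(pat.lhs, e.lhs, binds) and match(pat.rhs, e.rhs, binds))

# Mirror pattern1: res = z + a * v over i32 values.
v, z, a = (ValueHandle("i32") for _ in range(3))
res = z + a * v

x, y, w = Val("x", "i32"), Val("y", "i32"), Val("w", "i32")
binds: Dict[int, Expr] = {}
print(match(res, Op("add", w, Op("mul", x, y)), binds))   # True
print(match(res, Op("add", Op("mul", x, y), w), {}))      # False
```

The second call fails because z binds to the whole product and the remaining operand w is not a multiplication, which shows why operand order matters in this toy.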

Once a pattern is defined, we can apply the pattern to a schedule using the apply() API as follows:

    def montgomery_inv(A, M, v, k, z):
        ... ...

    s = hcl.create_schedule([Ap, Mp, vp, kp, zp], montgomery_inv)
    # positional pattern first: Python rejects positional args after keyword args
    p = s.apply(pattern1, name="pattern1", benefit=0)

Detailed description: Under the hood, this patch leverages the PDL and Transform dialects in MLIR to drive the transformation. User-defined patterns written with the Pattern APIs are first translated into the corresponding PDL and Transform IR; the apply() API then invokes the interpreter provided by the Transform dialect to apply the transformation to the schedule.
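The translate-then-interpret flow can be pictured with a pure-Python toy: calls made inside the transform session only record instructions (standing in for Transform IR), and a separate interpreter replays them against the payload the pattern matched. TransformScript, Node, interpret and the integer handle numbering are hypothetical and far simpler than the real Transform dialect interpreter; the method names merely mirror get_parent_loop() and loop_unroll() from pattern1.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

@dataclass
class Node:
    """Toy payload IR object: an operation or a loop, linked to its parent loop."""
    name: str
    parent: Optional["Node"] = None
    attrs: Dict[str, int] = field(default_factory=dict)

class TransformScript:
    """Hypothetical recorder: each call appends an instruction instead of acting."""

    def __init__(self) -> None:
        self.ops: List[Tuple] = []
        self._next_handle = 0          # handle 0 is reserved for the matched op

    def get_parent_loop(self, handle: int) -> int:
        self._next_handle += 1
        self.ops.append(("get_parent_loop", handle, self._next_handle))
        return self._next_handle

    def loop_unroll(self, handle: int, factor: int) -> None:
        self.ops.append(("loop_unroll", handle, factor))

def interpret(script: TransformScript, matched: Node) -> None:
    """Replay the recorded instructions against one pattern match."""
    env: Dict[int, Node] = {0: matched}
    for op in script.ops:
        if op[0] == "get_parent_loop":
            _, src, dst = op
            env[dst] = env[src].parent
        elif op[0] == "loop_unroll":
            _, src, factor = op
            env[src].attrs["unroll"] = factor   # toy: only records the directive
        else:
            raise ValueError(f"unknown transform op: {op[0]}")

# Payload: a statement nested in loop j, which is nested in loop i.
loop_i = Node("i")
loop_j = Node("j", parent=loop_i)
stmt = Node("res", parent=loop_j)

# "Translate" pattern1's transform body into a script, then interpret it.
script = TransformScript()
target_loop = script.get_parent_loop(0)
script.loop_unroll(target_loop, 2)
interpret(script, stmt)
print(script.ops)      # [('get_parent_loop', 0, 1), ('loop_unroll', 1, 2)]
print(loop_j.attrs)    # {'unroll': 2}
```

Separating recording from interpretation means one recorded script can be replayed for every match the pattern finds.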

Also see https://github.com/cornell-zhang/hcl-dialect-prototype/pull/124 for more information.

Link to the tests: Only a trivial test is included to showcase this new feature.

hanchenye, Aug 03 '22 21:08

Did you forget uploading the montgomery.py file? I didn't see it in the PR, but it was used in the test file.

chhzh123, Aug 04 '22 02:08

@hanchenye can you provide the complete list of transforms that the pattern class currently supports? Also is there a difference between transform and rewrite?

I would prefer we shorten the names of some of the APIs -- p.loop_unroll => p.unroll; p.start_transform => p.start_rewrite(); p.end_transform_or_rewrite() => p.end_rewrite()

zhangzhiru, Aug 04 '22 16:08

Did you forget uploading the montgomery.py file? I didn't see it in the PR, but it was used in the test file.

@chhzh123 I've updated the test file. It no longer depends on montgomery.py.

hanchenye, Aug 06 '22 20:08

@zhangzhiru Thanks for the feedback.

can you provide the complete list of transforms that the pattern class currently supports?

The transforms supported in this patch include:

  • parent_loop
  • split
  • unroll
  • pipeline

More transforms will be supported in subsequent patches.
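The semantics of the split and unroll transforms listed above, applied as split(loop, 2) followed by unroll(inner, 2), can be checked on a toy loop: splitting by 2 rewrites i as 2 * outer + inner, and unrolling the 2-iteration inner loop by 2 replaces it with two straight-line copies of the body, so the visited iterations are unchanged. This is a hand-written pure-Python illustration, not code generated by the patch. It assumes the trip count is divisible by the factor, and it only notes pipeline in a comment because pipelining is a hardware scheduling directive that does not change iteration order.

```python
from typing import Callable, List

def run_original(n: int, body: Callable[[int], None]) -> None:
    for i in range(n):
        body(i)

def run_split_unrolled(n: int, body: Callable[[int], None]) -> None:
    # split(loop, 2): i = 2 * outer + inner, with inner in [0, 2)
    # unroll(inner, 2): the 2-iteration inner loop becomes two straight-line copies
    assert n % 2 == 0, "toy sketch assumes the split factor divides the trip count"
    for outer in range(n // 2):   # pipeline(outer, ...) would only affect HW scheduling
        body(2 * outer + 0)
        body(2 * outer + 1)

trace_a: List[int] = []
trace_b: List[int] = []
run_original(8, trace_a.append)
run_split_unrolled(8, trace_b.append)
print(trace_a == trace_b)   # True
```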

I would prefer we shorten the names of some of the APIs -- p.loop_unroll => p.unroll; p.start_transform => p.start_rewrite(); p.end_transform_or_rewrite() => p.end_rewrite()

I've updated the APIs to use decorators, which are more intuitive than the previous APIs:

@is_transform
def loop_transform(target):
    target_loop = parent_loop(target, 1)
    outer_loop, inner_loop = split(target_loop, 2)
    unroll(inner_loop, 2)
    pipeline(outer_loop, 1)

@is_pattern(benefit=0)
def pattern1():
    dtype = hcl.Int(32)
    a = value(dtype)
    b = value(dtype)
    c = value(dtype)
    res = a * b + c
    loop_transform(res)

def main(M=32, N=32, K=32):
    ... ...
    def gemm(A, B):
        ... ...

    s = hcl.create_schedule([A, B], gemm)
    s.apply(pattern1)
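The decorator-based API can be pictured with a small pure-Python sketch: is_pattern(benefit=...) is a decorator factory that attaches metadata to the pattern function, and a schedule's apply() can use that benefit to order patterns. Everything here (ToySchedule, its log, passing the schedule into each pattern, unroll_by_2, pattern2) is hypothetical. The ordering rule, higher benefit first, follows MLIR's general pattern-benefit convention and is not confirmed for this patch.

```python
from typing import Callable, List

def is_pattern(benefit: int = 0) -> Callable[[Callable], Callable]:
    """Hypothetical: attach a benefit to a pattern function (decorator factory)."""
    def wrap(fn: Callable) -> Callable:
        fn.benefit = benefit
        return fn
    return wrap

def is_transform(fn: Callable) -> Callable:
    """Hypothetical: mark a function as a transform-session terminator."""
    fn.is_transform = True
    return fn

class ToySchedule:
    def __init__(self) -> None:
        self.log: List[str] = []

    def apply(self, *patterns: Callable) -> None:
        # Patterns with a higher benefit are tried first (MLIR convention).
        for pat in sorted(patterns, key=lambda p: p.benefit, reverse=True):
            pat(self)

@is_transform
def unroll_by_2(s: ToySchedule, name: str) -> None:
    s.log.append(f"unroll {name} by 2")

@is_pattern(benefit=0)
def pattern1(s: ToySchedule) -> None:
    unroll_by_2(s, "pattern1-match")

@is_pattern(benefit=5)
def pattern2(s: ToySchedule) -> None:
    unroll_by_2(s, "pattern2-match")

s = ToySchedule()
s.apply(pattern1, pattern2)
print(s.log)   # ['unroll pattern2-match by 2', 'unroll pattern1-match by 2']
```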

Also is there a difference between transform and rewrite?

With the new API, an @is_pattern function can be terminated by a call to either an @is_transform function or an @is_rewrite function. For now, the two use different backends for the IR transformation. @is_transform functions use the interpreter infrastructure provided by the Transform dialect, which can be extended to support more complicated transforms, such as the split and unroll shown in the example above. @is_rewrite functions use the infrastructure provided by the PDL dialect, which only supports simple operation replacement and erasure and cannot be extended. Ultimately, we should find a way to merge the two and expose a single entry point to users at the Python level.
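The difference between the two terminators can be illustrated with a toy rewrite in the PDL style: a rewrite only replaces the matched operation in place, here turning pattern1's a * b + c into a single fused node, whereas a transform such as split or unroll restructures the loops around the match. Val, Op, rewrite_muladd and the "fma" operation are hypothetical names for this pure-Python sketch.

```python
from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class Val:
    name: str

@dataclass(frozen=True)
class Op:
    kind: str        # "add", "mul", or the replacement "fma"
    args: tuple

Expr = Union[Val, Op]

def rewrite_muladd(e: Expr) -> Expr:
    """Replace every add(mul(a, b), c) with a single fma(a, b, c), bottom-up."""
    if isinstance(e, Val):
        return e
    args = tuple(rewrite_muladd(x) for x in e.args)
    if e.kind == "add" and isinstance(args[0], Op) and args[0].kind == "mul":
        a, b = args[0].args
        return Op("fma", (a, b, args[1]))   # replace the matched root op
    return Op(e.kind, args)

a, b, c = Val("a"), Val("b"), Val("c")
expr = Op("add", (Op("mul", (a, b)), c))
print(rewrite_muladd(expr) == Op("fma", (a, b, c)))   # True
```

Nothing outside the matched subtree changes, which is the kind of local replacement the PDL-based @is_rewrite path is limited to.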

hanchenye, Aug 06 '22 21:08