
Generalize loop pipelining

Open htyu opened this issue 1 year ago • 4 comments

We think that pipelining loads in a loop in general, regardless of whether they are involved in dot operations, can help performance by reducing the burden on hardware scheduling. This change enables that.
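To illustrate the idea, here is a plain-Python model (not Triton code; the function name and shape are hypothetical) of software-pipelining loads in a loop that does not contain a dot operation: with `num_stages` stages, each iteration's load is issued `num_stages - 1` iterations ahead of its use, so load latency can overlap with compute.

```python
# Illustrative sketch only: models how a pipeliner schedules loads in a
# non-matmul loop. `memory[i]` stands in for a tl.load of iteration i.

def pipelined_sum(memory, n, num_stages=2):
    """Sum memory[0:n], issuing each load num_stages - 1 iterations early."""
    assert num_stages >= 1
    inflight = []  # models loads that have been issued but not yet consumed
    # Prologue: issue the first num_stages - 1 loads before the loop body runs.
    for i in range(min(num_stages - 1, n)):
        inflight.append(memory[i])
    acc = 0
    # Steady state: each iteration issues a future load, then consumes the
    # oldest in-flight load (which was issued num_stages - 1 iterations ago).
    for i in range(n):
        nxt = i + num_stages - 1
        if nxt < n:
            inflight.append(memory[nxt])
        acc += inflight.pop(0)
    return acc
```

With `num_stages=1` this degenerates to the unpipelined loop (issue and consume in the same iteration), which is why the pass only kicks in for more than one stage.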

htyu avatar Feb 07 '24 19:02 htyu

I'm sending out this diff to get early feedback. For performance testing, I'm still looking for memory-bound kernels with heavy computation; please share if you have such kernels. They may need to be rewritten in a persistent style.

htyu avatar Feb 07 '24 19:02 htyu

if loop_annotation || (matmul_loop && global_num_stage > 1)

Sounds good to check against the annotation. How do you think we should handle matmul loops with extra loads? E.g., one of the test changes in the patch covers that case: a load of an indexing tensor, which is in turn used to load one of the dot operands, is pipelined.
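For concreteness, a plain-Python model of that dependent-load pattern (names are hypothetical, not Triton's API): the loop first loads an index, then uses it for a gather. Pipelining the index load makes the gather's address available a stage early, which is the chained-load situation discussed below.

```python
# Illustrative sketch only: each iteration does
#   idx = load(index_mem[i]); val = load(data_mem[idx])
# and the index load is issued one iteration ahead of its use.

def gather_with_pipelined_index_load(index_mem, data_mem, n):
    """Gather data_mem values through index_mem, prefetching indices."""
    out = []
    next_idx = index_mem[0] if n > 0 else None  # prologue: prefetch first index
    for i in range(n):
        idx = next_idx                 # consume the prefetched index
        if i + 1 < n:
            next_idx = index_mem[i + 1]  # issue next iteration's index load early
        out.append(data_mem[idx])      # dependent (gather) load uses it
    return out
```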

htyu avatar Feb 08 '24 01:02 htyu

A load of an indexing tensor, which is in turn used to load one of the dot operands, is pipelined.

Hey @htyu I am finishing up PR that will handle exactly the case of dependent loads in the loop. There is some overlap between our changes, but mine is more specific, trying to get optimal pipelining for chained loads. I do not handle other loads for now.

pawelszczerbuk avatar Feb 08 '24 18:02 pawelszczerbuk

A load of an indexing tensor, which is in turn used to load one of the dot operands, is pipelined.

Hey @htyu I am finishing up PR that will handle exactly the case of dependent loads in the loop. There is some overlap between our changes, but mine is more specific, trying to get optimal pipelining for chained loads. I do not handle other loads for now.

Thanks for letting me know. I'll probably restrict my current change to annotated loop only.

htyu avatar Feb 08 '24 21:02 htyu

should this be moved to a draft?

ThomasRaoux avatar Feb 21 '24 15:02 ThomasRaoux

should this be moved to a draft?

Have we decided not to support this feature in the short term?

If we also add the loop_range annotation, could we merge this PR?

Jokeren avatar Feb 21 '24 16:02 Jokeren

should this be moved to a draft?

Have we decided not to support this feature in the short term?

If we also add the loop_range annotation, could we merge this PR?

I think we can support this feature. I was asking because the PR is out of date right now, but I'm fine with rebasing it and adding support for non-matmul loop pipelining. The loop_range annotation is already supported, and adding pipelining for non-matmul loops based on it makes sense.

ThomasRaoux avatar Feb 21 '24 16:02 ThomasRaoux

I think we can support this feature. I was asking because the PR is out of date right now, but I'm fine with rebasing it and adding support for non-matmul loop pipelining. The loop_range annotation is already supported, and adding pipelining for non-matmul loops based on it makes sense.

I'll go make changes on top of this PR. It turns out that it might help for one use case I have.

Jokeren avatar Feb 21 '24 16:02 Jokeren

I think we can support this feature, I was asking as the PR is out of date right now but it is fine to me if we want to rebase and add support for non-matmul loop pipelining. The loop_range annotation is already supported and I adding pipelining for non-matmul loops based on that makes sense.

I'll go make changes on top of this PR. It turns out that it might help for one use case I have.

Good to know it's helping your case. Would you mind sharing what your case looks like?

In the latest version I made the optimization restricted to loops with tt.num_stages annotation only. Is loop_range a different annotation?

htyu avatar Feb 21 '24 17:02 htyu

Good to know it's helping your case. Would you mind sharing what your case looks like?

Unfortunately I'm not able to share the code since it's an ongoing project.

It's something like:

for i in range(...):
    a += tl.load(...)
tl.dot(a, b)
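A fleshed-out plain-Python model of that pattern (the function and its arguments are illustrative, not the actual kernel): the loop only accumulates loads, with no dot inside it, so its loads are exactly the non-matmul case this PR would pipeline; the dot happens once after the loop.

```python
# Illustrative sketch only: a reduce-then-dot pattern. Each `chunk` models
# one tl.load(...) inside the loop; the final line models tl.dot(a, b).

def reduce_then_dot(chunks, b):
    """Accumulate row vectors in a loop, then dot the result with b."""
    a = [0.0] * len(b)
    for chunk in chunks:                       # in-loop loads, no dot here
        a = [x + y for x, y in zip(a, chunk)]  # a += tl.load(...)
    return sum(x * y for x, y in zip(a, b))    # tl.dot(a, b) after the loop
```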

In the latest version I made the optimization restricted to loops with tt.num_stages annotation only. Is loop_range a different annotation?

Oh, in this case I think we just need to update the code and consolidate test cases

Jokeren avatar Feb 21 '24 17:02 Jokeren

In the latest version I made the optimization restricted to loops with tt.num_stages annotation only. Is loop_range a different annotation?

Oh, in this case I think we just need to update the code and consolidate test cases

+1, thanks

ThomasRaoux avatar Feb 21 '24 17:02 ThomasRaoux

I will rebase. The test case (in test/TritonGPU/loop-pipeline.mlir) has been updated with that loop annotation. Let me know if it looks good. Thanks.

htyu avatar Feb 21 '24 17:02 htyu

Rebasing done.

htyu avatar Feb 22 '24 01:02 htyu

One final comment is that maybe we want to lift the logic out of MatmulLoopPipeline later since it's not "matmul" anymore?

Jokeren avatar Feb 22 '24 15:02 Jokeren

One final comment is that maybe we want to lift the logic out of MatmulLoopPipeline later since it's not "matmul" anymore?

It's a good point. Or maybe we could rename MatmulLoopPipeline.cpp to something more general, like LoopPipelining.cpp? I'll leave that to a separate refactoring patch.

htyu avatar Feb 22 '24 17:02 htyu