[GPU] Support multiple contraction dims in MmaSchedules
This adds support for multiple M, N, and K dims in problems when deducing a GPUMMASchedule. The new heuristic is similar to the old one, but works on pairs of M and N dims. For example:
tensor<M1xM0xK1xK0> * tensor<N1xN0xK1xK0> -> tensor<M1xN1xM0xN0>
This will try to distribute the seeded tile counts to M0 and N0 (first attempting to distribute evenly, and then distributing to N followed by M), and then distribute the residual counts to M1 and N1. The K tile counts will be partitioned to K0 first, and then the residual tile counts will be partitioned to K1.
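As a rough illustration of the "innermost first, then residual" partitioning described for K, here is a minimal Python sketch. It is not the code from this PR: the function name `distribute_innermost_first` is hypothetical, and using the GCD to decide how much of the seeded count each dimension can absorb is an assumption made for the example.

```python
from math import gcd


def distribute_innermost_first(dim_tile_counts, seeded_count):
    """Split a seeded tile count across dims, innermost dim first.

    dim_tile_counts: tiles available in each dim, ordered outermost to
        innermost (e.g. [K1, K0]).
    seeded_count: total tile count the heuristic wants to distribute.

    Assumption (not taken from the PR): each dim absorbs
    gcd(available, remaining) so the assigned count divides both.
    """
    assigned = [1] * len(dim_tile_counts)
    remaining = seeded_count
    # Walk from the innermost dim (last entry) outwards.
    for i in reversed(range(len(dim_tile_counts))):
        share = gcd(dim_tile_counts[i], remaining)
        assigned[i] = share
        remaining //= share  # the residual goes to the next outer dim
    return assigned


# K0 can absorb the whole seeded count, so K1 gets nothing extra.
print(distribute_innermost_first([4, 8], 4))
# K0 only has 2 tiles, so the residual factor of 2 goes to K1.
print(distribute_innermost_first([4, 2], 4))
```

The same idea would apply to the M/N pairs, with the extra step of balancing the seeded count between M0 and N0 before moving to M1 and N1.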
This PR also updates the config selection logic for the TileAndFuse pipeline to make use of the multiple contraction dimensions in mma schedules.
Depends on https://github.com/iree-org/iree/pull/18565
Oh man, the tuner will need some updating after this lands...
Oh, I would have thought this doesn't break the tuner. Does the tuner use the logic here somehow?
We discussed this offline and I misunderstood this part of the PR: the configuration logic changes but the attributes do not, so no need to update the tuner.
I understand how subgroup and workgroup tiles are distributed with multiple m/n/k dimensions, but I'm slightly confused about what happens with the thread tiles. Are we only targeting the innermost dimensions, or are we targeting a reshaped version of the innermost dimensions (by packing)?
The m/n/kTileCounts naming was confusing, and it is more like m/n/kTileSizes. We target the intrinsic shape to the innermost dimension, so the innermost m/n/kTileSizes need to be aligned with (or able to pad to) the intrinsic shapes. The outer tile sizes can be thought of as unrolling factors of the innermost tile sizes. I added some more docs to help explain this, LMK what you think.
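To make the "outer tile sizes as unrolling factors" reading concrete, here is a small Python sketch. It assumes, for illustration only, that tile sizes are expressed in elements and that the innermost size is padded up to a multiple of the intrinsic size; the helper names `padded_inner_size` and `intrinsic_invocations` are hypothetical and do not come from the PR.

```python
def padded_inner_size(inner_size, intrinsic_size):
    """Round the innermost tile size up to a multiple of the intrinsic size."""
    return -(-inner_size // intrinsic_size) * intrinsic_size


def intrinsic_invocations(tile_sizes, intrinsic_size):
    """Count intrinsic-sized steps along one of M, N, or K.

    tile_sizes: sizes ordered outermost to innermost (e.g. [M1, M0]).
    Assumption: only the innermost dim maps onto the intrinsic; every
    outer size simply repeats (unrolls) the innermost tile.
    """
    *outer, inner = tile_sizes
    inner_tiles = padded_inner_size(inner, intrinsic_size) // intrinsic_size
    unroll = 1
    for size in outer:
        unroll *= size
    return unroll * inner_tiles


# M0 = 30 pads to 32, i.e. two 16-wide intrinsics, repeated M1 = 2 times.
print(padded_inner_size(30, 16))
print(intrinsic_invocations([2, 30], 16))
```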
While VectorDistribute doesn't support multiple dimensions for subgroup dims, can we try to keep the configuration logic similar to TileAndFuse? We plan to support that soon, and it would involve changing a lot of things this patch already did. I mentioned it in some places, but it would be nice if we can get them to be similar. This is of course a best-effort request, and not blocking in any way.
Thanks for all the comments! I basically just did a find and replace for all the VectorDistribute logic to use only the first dimension for M, N, and K. I chose to leave the functionality the same because I don't understand the VectorDistribute pipeline well enough to know if we can/how to change the configuration logic to support multiple contraction dimensions. I will just leave it as TODO comments for now, since I don't have enough knowledge of VectorDistribute to support it properly.
I'm happy to help with adding the support for VectorDistribute (and attention schedules) as well, but I think it is best to do it later, since this PR is needed for flipping the IGEMM flag.
Sure, that's fair enough. TODO comments sound good. Thanks! I'll let Quinn take the final approval on this since he mentioned he wanted to have a look at this. Soft approval from my side.