[VectorDistribution] Add support for multi-subgroup attention
This patch adds support for distributing attention to multiple subgroups. Some points to note:
- Due to some issues with layout analysis, we cannot yet do multiple n subgroups. This is fine because, by the nature of attention, the n dimension is quite small and generally not distributed (the Flash Attention paper doesn't even consider this dimension for distribution!).
- Currently, we distinguish the two matmuls in attention by setting a discardable attribute on them during decomposition, used as a hint telling layout anchoring what to do when it encounters these matmuls (see the sketch below). Note that even if these hints were dropped, it would only cost performance, because layout anchoring would no longer know it is looking at attention. The correct way to handle these matmuls would be to make mma_schedule an operation-specific lowering config and teach decomposition to propagate this config to the two matmuls. That is blocked on TileAndDistributeToWorkgroups supporting consumer fusion, and needs some heavy lifting.
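For illustration, here is a minimal sketch of the hint mechanism, assuming only MLIR's standard discardable-attribute API; the attribute names are hypothetical placeholders, not the ones in the patch:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical hint names; the real ones are an internal detail of the patch.
static constexpr StringRef kQKHint = "attention_qk_matmul";
static constexpr StringRef kPVHint = "attention_pv_matmul";

// Called during decomposition: tag the two matmuls so layout anchoring
// can tell the QK and PV matmuls apart later.
void tagAttentionMatmuls(Operation *qkMatmul, Operation *pvMatmul) {
  MLIRContext *ctx = qkMatmul->getContext();
  qkMatmul->setAttr(kQKHint, UnitAttr::get(ctx));
  pvMatmul->setAttr(kPVHint, UnitAttr::get(ctx));
}

// Called during layout anchoring: consult the hint if it survived.
// If the attribute was discarded, anchoring falls back to generic
// matmul layouts (still correct, just slower).
bool isQKMatmul(Operation *op) { return op->hasAttr(kQKHint); }
bool isPVMatmul(Operation *op) { return op->hasAttr(kPVHint); }
```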
Is there a way you can maybe avoid decomposing the attention operation until vector distribution and handle the layout distribution for attention directly?
Today, IIUC, we'd need to decompose earlier than vector distribution because attention decomposes into non-trivial ops such as matmuls and shuffles/reductions (and, to a lesser extent, reads and broadcasts), which require layout analysis and vector distribution to ensure the thread-distributed shapes play nice with each other.
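Concretely, the decomposition follows the standard flash-attention recurrence; per KV block $j$, with running row-max $m$ and row-sum $\ell$:

```math
\begin{aligned}
S_j &= Q K_j^T && \text{(matmul)}\\
m' &= \max(m, \mathrm{rowmax}(S_j)) && \text{(shuffle/reduction)}\\
P_j &= \exp(S_j - m') && \text{(broadcast of } m'\text{)}\\
\ell' &= e^{m - m'}\,\ell + \mathrm{rowsum}(P_j) && \text{(shuffle/reduction)}\\
O' &= e^{m - m'}\,O + P_j V_j && \text{(matmul)}
\end{aligned}
```

Every intermediate here needs a thread layout that agrees across both matmuls and the cross-lane reductions, which is exactly what layout analysis resolves.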
> Is there a way you can maybe avoid decomposing the attention operation until vector distribution and handle the layout distribution for attention directly?

Probably not… if we did that, we would effectively be writing microkernels for attention, hardcoded for each intrinsic type at thread level. Which is fine… but I'm not sure we want to do that.

One thing we could do is subgroup distribution at the attention op level and thread distribution after decomposition. That would require a major rewrite of vector distribution, splitting it into subgroup-level and thread-level distribution, and I'm not sure we can cleanly split it up.

I'd rather land this patch and invest effort in teaching TileAndFuse to handle attention than rewrite VectorDistribution.
> > Is there a way you can maybe avoid decomposing the attention operation until vector distribution and handle the layout distribution for attention directly?
>
> Probably not… if we did that, we would effectively be writing microkernels for attention, hardcoded for each intrinsic type at thread level. Which is fine… but I'm not sure we want to do that.
>
> One thing we could do is subgroup distribution at the attention op level and thread distribution after decomposition. That would require a major rewrite of vector distribution, splitting it into subgroup-level and thread-level distribution, and I'm not sure we can cleanly split it up.
Well, we could also just decompose within the pass as a pre-processing step. Then the attribute becomes an internal detail of the pass.
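A rough sketch of what that could look like, assuming MLIR's standard pass and pattern machinery; the pattern name is a hypothetical placeholder, not an existing IREE API:

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

struct VectorDistributeWithDecompose
    : PassWrapper<VectorDistributeWithDecompose, OperationPass<>> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Pre-processing: decompose attention into matmuls/reductions and
    // set the hint attributes. The attributes never escape this pass.
    RewritePatternSet decompose(ctx);
    // decompose.add<DecomposeAttentionPattern>(ctx);  // hypothetical
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(decompose))))
      return signalPassFailure();

    // Main work: layout analysis + vector distribution, which can now
    // consume the hint attributes before they are discarded.
    // ... existing distribution logic ...
  }
};
```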
> I'd rather land this patch and invest effort in teaching TileAndFuse to handle attention than rewrite VectorDistribution.
Ok, I stamped it, but please add TODOs/warnings noting that this is unstable.
There are some tests that exceed shared memory, so I'm going to wait for https://github.com/iree-org/iree/pull/18415 to land before landing this.