[NestedLayout] Make layout use strides instead of basis
This patch makes the layout store how threads and subgroups are distributed as "strides" instead of as a basis.
Strides simply define a mapping from virtual tid --> tid (where tid is the thread id). The mapping is a dot product between the virtual tid and the strides:
tid = sum_i (vtid_i * stride_i)
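A minimal Python sketch of the forward map, assuming a layout is modeled as plain per-dimension lists. The helper name `vtid_to_tid` is hypothetical and is not part of the patch or the IREE API.

```python
def vtid_to_tid(vtid, strides):
    """Map a virtual thread id (one coordinate per dimension) to a thread id.

    Hypothetical helper: tid is the dot product of vtid and strides.
    """
    if len(vtid) != len(strides):
        raise ValueError("vtid and strides must have the same rank")
    return sum(v * s for v, s in zip(vtid, strides))


# A 2 x 4 virtual grid where dimension 1 varies fastest.
row_major = [4, 1]
# The same 2 x 4 grid, but with dimension 0 varying fastest.
col_major = [1, 2]

print(vtid_to_tid([1, 2], row_major))  # 1*4 + 2*1 = 6
print(vtid_to_tid([1, 2], col_major))  # 1*1 + 2*2 = 5
```

Changing which dimension is contiguous only changes the stride values; no separate order field is involved.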
If we put some restrictions on the strides, the mapping can be made invertible, with the inverse mapping:
vtid_i = (tid floordiv stride_i) mod size_i
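The restrictions are not spelled out here, so the sketch below assumes one sufficient condition (not necessarily the one used in the patch): when dimensions are sorted by stride, each stride is a positive multiple of the previous stride times the previous size, i.e. the dimensions occupy non-overlapping, possibly padded blocks of tid. Under that assumption it round-trips every virtual id through the forward and inverse maps. `is_invertible`, `vtid_to_tid`, and `tid_to_vtid` are hypothetical names.

```python
from itertools import product


def is_invertible(sizes, strides):
    """Check an assumed sufficient condition for invertibility."""
    dims = sorted(zip(strides, sizes))
    if any(s <= 0 for s, _ in dims):
        return False
    # Each larger stride must be a multiple of (smaller stride * its size).
    for (s_prev, n_prev), (s_next, _) in zip(dims, dims[1:]):
        if s_next % (s_prev * n_prev) != 0:
            return False
    return True


def vtid_to_tid(vtid, strides):
    """Forward map: tid = sum_i vtid_i * stride_i."""
    return sum(v * s for v, s in zip(vtid, strides))


def tid_to_vtid(tid, sizes, strides):
    """Inverse map: vtid_i = (tid floordiv stride_i) mod size_i."""
    return [(tid // s) % n for n, s in zip(sizes, strides)]


sizes = [2, 4]
for strides in ([4, 1], [1, 2], [1, 8]):  # row-major, column-major, padded
    assert is_invertible(sizes, strides)
    for vtid in product(*(range(n) for n in sizes)):
        tid = vtid_to_tid(vtid, strides)
        assert tid_to_vtid(tid, sizes, strides) == list(vtid)

print(tid_to_vtid(6, sizes, [4, 1]))  # [1, 2]
print(is_invertible(sizes, [2, 1]))   # False: the two dimensions overlap
```

With strides `[2, 1]`, for example, vtids `(0, 2)` and `(1, 0)` both map to tid 2, which is why the check rejects it.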
The advantage of strides is that a stride depends only on the dimension it describes, while a basis depends on the order of dimensions: to get the distribution for dimension i, you need to know the distribution for dimension i-1. (If you think about it, strides are just a skewing of the basis that removes the dependency of i on i-1.)
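To illustrate the skewing remark, here is a sketch that models a basis in a deliberately simplified way: a plain delinearization with dimension 0 innermost (the real NestedLayout basis also carries order and active ids, which are left out). Basis-style delinearization has to peel dimensions in sequence, while precomputing each stride as the product of the inner basis sizes lets every dimension be computed on its own. All function names are hypothetical.

```python
from math import prod


def delinearize_with_basis(tid, basis):
    """Basis-style: dimension 0 is innermost. Dimension i can only be
    computed after dimensions 0..i-1 have consumed their part of tid."""
    coords = []
    for size in basis:
        coords.append(tid % size)
        tid //= size
    return coords


def basis_to_strides(basis):
    """Skew the basis: stride_i = product of the sizes of dimensions 0..i-1."""
    return [prod(basis[:i]) for i in range(len(basis))]


def delinearize_with_strides(tid, sizes, strides):
    """Stride-style: each dimension depends only on its own size and stride."""
    return [(tid // s) % n for n, s in zip(sizes, strides)]


basis = [4, 3, 2]
strides = basis_to_strides(basis)
print(strides)  # [1, 4, 12]

# Both forms agree on every thread id covered by the basis.
for tid in range(prod(basis)):
    assert delinearize_with_basis(tid, basis) == delinearize_with_strides(tid, basis, strides)
```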
With strides, dimensions are essentially independent, and we do not need any order in the layout, which hugely simplifies it. This is also much closer to how CuTe layouts work (somewhat).
Along with this change, distribution now actually uses 2 new ops to query the tid -> virtual tid map, which makes distribution independent of these changes.
The motivation for this change is that with strides, layouts actually have a canonical form. With a basis they do not, because different combinations of active ids, order, and basis permutations can represent the same layout. Having a canonical form helps a lot when trying to match two different matmul anchors.
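A toy illustration of the canonical-form argument, under a hypothetical, simplified encoding of a basis-style description: a list of `(size, dim)` entries, innermost first, where `dim` names the layout dimension the entry distributes or is `None` for an inactive entry. This is not the actual NestedLayout attribute; it only shows how several such descriptions collapse to the same strides, while a genuinely different order yields different strides.

```python
from math import prod


def strides_from_description(entries, rank):
    """entries: (size, dim) pairs, innermost first. dim is the layout
    dimension the entry distributes, or None for an inactive entry.
    Returns one stride per layout dimension (hypothetical encoding)."""
    strides = [None] * rank
    for pos, (_, dim) in enumerate(entries):
        if dim is not None:
            # Skew: fold the sizes of all inner entries into one stride.
            strides[dim] = prod(s for s, _ in entries[:pos])
    return strides


# Three descriptions of the same 2 x 4 layout with dimension 1 innermost.
desc_a = [(4, 1), (2, 0)]
desc_b = [(1, None), (4, 1), (2, 0)]  # extra inactive entry, innermost
desc_c = [(4, 1), (1, None), (2, 0)]  # extra inactive entry, in the middle
# A genuinely different layout: dimension 0 innermost.
transposed = [(2, 0), (4, 1)]

for desc in (desc_a, desc_b, desc_c):
    print(strides_from_description(desc, 2))     # [4, 1] each time
print(strides_from_description(transposed, 2))   # [1, 2]
```

In this model, deciding whether two layouts are the same reduces to comparing their sizes and strides for equality.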
TODO:
- Check correctness on matmul test suite for mfma (DONE)
- Check correctness on matmul test suite for wmma (DONE)
- Check sdxl correctness
- Fix tests (DONE)
- Add loads of docs for the layout
I have proofs that this is equivalent to the basis form, and of how the bidirectional tid <-> virtual tid mapping works. I will include them in the docs once I have checked correctness.
@Groverkss Bravo on this PR! Feels so much cleaner and simpler! Just have a couple of Qs and NITs :)