[Encoding] Add an optional bcast_map attribute to EncodingAttr.
This is a data-tiling requirement for dequant fusion. We will need to encode broadcasting information in the encoding when we set encodings on dequant ops. This field is a placeholder for that case.
Progress towards https://github.com/iree-org/iree/issues/17718
This is going to get tricky. Do we really need this on the encoding attribute? Could get out of hand pretty quickly.
Yes, I think this is needed for broadcast cases. Say that we have `broadcasting_dequant + matmul` in a dispatch, and we need to allocate a buffer for the input of the dispatch. The original indexing_map from `batch_matmul` is `(b, m, n, k) -> (b, n, k)`, and we are broadcasting across the batch dimension (i.e., the indexing_map in the broadcast op is `(b, n, k) -> (n, k)`). At the stream level, we want to allocate an `n * k` buffer. If we don't encode the broadcast indexing map, we don't know the allocation size: the missing dimension could be any of `b`, `n`, and `k`. I think this is the cost we pay for broadcast fusion.
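A small sketch of the allocation-size argument, assuming affine projection maps are modeled as tuples of the domain positions they keep (a hypothetical Python model, not IREE's attribute API; the loop sizes are made up for illustration). With the broadcast map the source shape is fully determined; without it, dropping one dim of `(b, n, k)` gives three candidate shapes with three different sizes.

```python
from itertools import combinations
from math import prod

def apply_projection(results, sizes):
    """Project `sizes` with a map that keeps the domain positions in `results`."""
    return [sizes[i] for i in results]

# batch_matmul RHS map (b, m, n, k) -> (b, n, k), as kept positions.
rhs_map = (0, 2, 3)
# Broadcast map (b, n, k) -> (n, k), as kept positions.
bcast_map = (1, 2)

# Hypothetical loop sizes, ordered as (b, m, n, k).
iter_sizes = [4, 128, 256, 512]

rhs_shape = apply_projection(rhs_map, iter_sizes)       # [4, 256, 512]
source_shape = apply_projection(bcast_map, rhs_shape)   # [256, 512]
num_elements = prod(source_shape)                       # n * k = 131072

# Without bcast_map, any two of the three RHS dims could be the source dims.
candidates = [apply_projection(c, rhs_shape) for c in combinations(range(3), 2)]
# [[4, 256], [4, 512], [256, 512]]: three different allocation sizes
```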
I don't think we can reuse the original `indexing_maps` field, because it makes the logic very tricky. We may need the original matmul indexing maps to infer contraction dims, so replacing the corresponding indexing map with the broadcast map looks very bad to me. Decoupling the logic into a new field reduces the complexity. What do you think?
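To illustrate why overwriting the RHS entry of `indexing_maps` is risky, here is a hypothetical Python model (`EncodingSketch` and `infer_contraction_dims` are illustrative names, not IREE code). It classifies loop dims the usual way for a contraction: batch dims appear in LHS, RHS, and output; M dims in LHS and output only; N dims in RHS and output only; K dims in LHS and RHS only.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class EncodingSketch:
    # (lhs, rhs, out) maps of the original matmul, each written as the tuple
    # of loop dims it selects, e.g. ("b", "n", "k") for (b, m, n, k) -> (b, n, k).
    user_indexing_maps: Tuple[Tuple[str, ...], ...]
    # Hypothetical separate field: operand positions kept by the broadcast.
    bcast_map: Optional[Tuple[int, ...]] = None

def infer_contraction_dims(lhs, rhs, out):
    lhs, rhs, out = set(lhs), set(rhs), set(out)
    return {
        "batch": sorted(lhs & rhs & out),
        "m": sorted((lhs & out) - rhs),
        "n": sorted((rhs & out) - lhs),
        "k": sorted((lhs & rhs) - out),
    }

maps = (("b", "m", "k"), ("b", "n", "k"), ("b", "m", "n"))
enc = EncodingSketch(user_indexing_maps=maps, bcast_map=(1, 2))
dims = infer_contraction_dims(*enc.user_indexing_maps)
# {'batch': ['b'], 'm': ['m'], 'n': ['n'], 'k': ['k']}

# If the RHS map were overwritten with the composed broadcast map (n, k),
# the batch dim would be misclassified as an M dim.
broken = infer_contraction_dims(("b", "m", "k"), ("n", "k"), ("b", "m", "n"))
# {'batch': [], 'm': ['b', 'm'], 'n': ['n'], 'k': ['k']}
```

Keeping `bcast_map` separate leaves the inference over `user_indexing_maps` unchanged.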
If there are concerns, I can prototype everything in a branch, and we can revisit whether this is needed.
One way to maybe fix it is to change the indexing maps to be a `list(list(indexing_maps))`. Basically, composing the indexing maps lets you get back to the original op. For example, for the broadcast case the indexing maps would be `[[affine_map<(b, m, n, k) -> (b, n, k)>, affine_map<(b, n, k) -> (n, k)>], ..., ...]`. This also generalizes (say, if you want to fuse a transpose or some arbitrary chain in the future).
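A minimal sketch of the `list(list(indexing_maps))` idea, assuming every map in a chain is a pure projection or permutation written as the tuple of domain positions it selects (a simplification; real affine maps can carry arbitrary expressions, and `compose_chain` is a hypothetical helper). Composing a chain left to right recovers the operand's dims in terms of the original op's loops:

```python
def compose_chain(chain, domain):
    """Apply each map in `chain` in order; every map's domain is the
    previous map's results, and each map is a tuple of kept positions."""
    dims = list(domain)
    for positions in chain:
        dims = [dims[p] for p in positions]
    return tuple(dims)

loops = ("b", "m", "n", "k")

# (b, m, n, k) -> (b, n, k), then the broadcast (b, n, k) -> (n, k).
bcast_chain = [(0, 2, 3), (1, 2)]
print(compose_chain(bcast_chain, loops))  # ('n', 'k')

# The same mechanism extends to a fused transpose (n, k) -> (k, n).
transpose_chain = bcast_chain + [(1, 0)]
print(compose_chain(transpose_chain, loops))  # ('k', 'n')
```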
I'm not sure we need to go as far as making a list of indexing maps. We should only need one extra indexing map that tells which dimensions correspond to which dimensions in the matmul, and that one indexing map can be composed as encodings are propagated. Having a long list of indexing maps doesn't seem helpful to me, since knowing the path that the tensor took to get to its state is not necessary. All that is needed is to know which dimensions are relevant to the encoding (encoded in a mapping from tensor dims to matmul dims) and how each dimension in the matmul is used (encoded in the current `user_indexing_maps`). Any other mapping is, I think, out of the scope of the encoding, since the encoding should really just tell us how each dimension of the tensor is used.
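The single-map alternative can be sketched like this, assuming the extra map is stored as one entry per tensor dim naming the matmul loop dim it corresponds to. `through_broadcast` and `through_transpose` are hypothetical helpers, and the transpose follows the `linalg.transpose` convention that result dim `i` is input dim `perm[i]`. Only this map is rewritten during propagation; `user_indexing_maps` stays untouched and no history is stored.

```python
from typing import Optional, Sequence, Tuple

# Entry i names the matmul loop dim that tensor dim i corresponds to.
DimMap = Tuple[Optional[str], ...]

def through_broadcast(result_map: DimMap, kept_dims: Sequence[int]) -> DimMap:
    """Propagate from a broadcast's result to its source: the source keeps
    only the result dims listed in `kept_dims`, in order."""
    return tuple(result_map[d] for d in kept_dims)

def through_transpose(result_map: DimMap, perm: Sequence[int]) -> DimMap:
    """Propagate from a transpose's result to its input, assuming
    result dim i == input dim perm[i]."""
    inverse = [0] * len(perm)
    for result_dim, input_dim in enumerate(perm):
        inverse[input_dim] = result_dim
    return tuple(result_map[inverse[j]] for j in range(len(perm)))

# batch_matmul RHS tensor<b x n x k>: each dim maps directly to a loop dim.
rhs: DimMap = ("b", "n", "k")
# Through a broadcast along b: the source is tensor<n x k>.
src = through_broadcast(rhs, (1, 2))             # ("n", "k")
# Through a transpose whose input is tensor<k x n>.
pre_transpose = through_transpose(src, (1, 0))   # ("k", "n")
```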
Keeping track of this dimension correspondence will also be very helpful in simplifying propagation analysis. For example, if an encoding is propagated far enough, it may reach a point where none of the original matmul dimensions are present in the tensor, at which point the encoding should be able to fold away. This would probably be necessary in many cases where different encodings share nodes on their use-def chains, and the shared nodes are some function input like `tensor<1xi64>` that gets broadcast.
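A sketch of the fold-away check under the same kind of model, assuming a `None` entry marks a tensor dim that is not tied to any matmul dim (for example, a unit dim); `through_broadcast` and `should_fold` are hypothetical names. Once every entry is `None`, the encoding carries no information and could be dropped.

```python
from typing import Optional, Sequence, Tuple

# Entry i names the matmul loop dim for tensor dim i, or None if there is none.
DimMap = Tuple[Optional[str], ...]

def through_broadcast(result_map: DimMap,
                      source_to_result: Sequence[Optional[int]]) -> DimMap:
    """source_to_result[j] is the result dim fed by source dim j, or None
    when source dim j is not tied to any result dim."""
    return tuple(None if r is None else result_map[r] for r in source_to_result)

def should_fold(dim_map: DimMap) -> bool:
    """No tensor dim maps to a matmul dim, so the encoding is meaningless."""
    return all(d is None for d in dim_map)

rhs_map: DimMap = ("n", "k")  # dequantized RHS, tensor<n x k>

# Per-channel scales tensor<n> broadcast along k: still tied to "n".
scale_map = through_broadcast(rhs_map, (0,))      # ("n",)
# A shared tensor<1xi64> input broadcast to every dim: nothing remains.
shared_map = through_broadcast(rhs_map, (None,))  # (None,)

print(should_fold(scale_map))   # False
print(should_fold(shared_map))  # True
```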
Converting this to a draft because we need more information. We will iterate on it in a branch and land it in main later.
Closing this because it was picked up in https://github.com/iree-org/iree/pull/18032