[SVE] Add support for representing and creating buffer-level predicates
Representation
This commit extends BufferLoad and BufferStore to accept a predicate mask argument indicating which lanes in a vectorized buffer load/store should be read/written.
As a simple example, we can load all lanes:
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(1, 8))
Or mask off every lane, so that no elements are loaded:
tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(0, 8))
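The per-lane meaning of the predicate can be modelled in plain Python. This is a minimal sketch, not TVM code: the helper names `ramp`, `predicated_load` and `predicated_store` are hypothetical, and representing a masked-off load lane as `None` is an illustrative assumption. It makes no claim about which value real TVM codegen produces for such lanes. The key property it shows is that a disabled lane never touches the buffer, so its index may even be out of bounds.

```python
from typing import List, Optional, Sequence


def ramp(base: int, stride: int, lanes: int) -> List[int]:
    """Indices covered by a Ramp(base, stride, lanes) expression."""
    return [base + stride * i for i in range(lanes)]


def predicated_load(buf: Sequence[float], indices: Sequence[int],
                    predicate: Sequence[bool]) -> List[Optional[float]]:
    """Read only the lanes whose predicate bit is set.

    Disabled lanes are modelled as None (hypothetical placeholder) and the
    buffer is never indexed for them.
    """
    return [buf[idx] if on else None for idx, on in zip(indices, predicate)]


def predicated_store(buf: list, indices: Sequence[int],
                     values: Sequence[Optional[float]],
                     predicate: Sequence[bool]) -> None:
    """Write only the lanes whose predicate bit is set."""
    for idx, value, on in zip(indices, values, predicate):
        if on:
            buf[idx] = value
```

For example, `predicated_load(buf, ramp(0, 1, 8), [True] * 8)` corresponds to the all-lanes load above, while `[False] * 8` corresponds to the fully masked one.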
In TVMScript, buffer loads and stores are currently displayed using a shorthand notation, e.g. A[0:4], but there is no clear path for extending this notation to support predicates. Therefore, the vload/vstore notation is used instead, e.g. A.vload([T.Ramp(0, 1, 4)], predicate=...). The TVMScript printer falls back to the vload/vstore notation whenever a predicate is specified.
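The printer rule described above can be summarised with a small helper. This is a hypothetical sketch of the decision only, not TVM's printer implementation: `print_buffer_load` is an invented name, it handles only unit-stride ramps (the case that maps onto `A[start:stop]`), and the predicate is passed in as an already-printed string.

```python
from typing import Optional


def print_buffer_load(name: str, base: int, lanes: int,
                      predicate: Optional[str] = None) -> str:
    """Pick shorthand slice notation, or fall back to vload when predicated.

    Hypothetical helper mirroring the rule in the text; unit stride only.
    """
    if predicate is None:
        # No predicate: the shorthand slice notation is sufficient.
        return f"{name}[{base}:{base + lanes}]"
    # A predicate cannot be expressed in slice notation, so use vload.
    return f"{name}.vload([T.Ramp({base}, 1, {lanes})], predicate={predicate})"
```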
Creation
Buffer-level predication becomes more compelling when combined with the tir.get_active_lane_mask intrinsic, which can be used to mask off the excess lanes when the extent of the vectorized axis is not divisible by the vector length. A detailed example and rationale can be found in the RFC.
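To make the lane-masking arithmetic concrete, the sketch below assumes the usual active-lane-mask semantics (lane `i` is enabled iff `base + i < limit`, as for LLVM's `llvm.get.active.lane.mask`); it does not call TVM, and `active_lane_mask` is a hypothetical name. For a 14-element axis split into 4-lane vectors there are ceil(14 / 4) = 4 iterations, and only the last one has disabled lanes.

```python
from typing import List


def active_lane_mask(lanes: int, base: int, limit: int) -> List[bool]:
    """Lane i is enabled iff base + i < limit (assumed semantics)."""
    return [base + i < limit for i in range(lanes)]


EXTENT, LANES = 14, 4
# ceil(EXTENT / LANES) outer iterations, one mask per vector.
masks = [active_lane_mask(LANES, i_0 * LANES, EXTENT)
         for i_0 in range((EXTENT + LANES - 1) // LANES)]
```

Here `masks[3]` is `[True, True, False, False]`: indices 12 and 13 are processed, while 14 and 15 fall outside the axis and are masked off.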
Predicated buffer loads/stores are created in the VectorizeLoop pass by TryPredicateBufferAccesses, which aims to convert block-level predicates, e.g.
for i_0 in T.serial(4):
    for i_1 in T.vectorized(4):
        if i_0 * 4 + i_1 < 14:
            B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0
to buffer-level predicates, e.g.
for i_0 in T.serial(4):
    predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14)
    A_load = T.meta_var(A.vload([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate))
    B.vstore([T.Ramp(i_0 * 4, 1, 4)], A_load + T.Broadcast(T.float32(1.0), 4), predicate=predicate)
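The block-level and buffer-level forms should compute the same result. The sketch below simulates both loops in plain Python for the 14-element example; it is an illustration under stated assumptions, not output of the VectorizeLoop pass. `block_level` and `buffer_level` are hypothetical names, the active lane mask is modelled as `base + lane < N`, and the element-wise `+ 1.0` from the original loop body is applied to the loaded lanes. If the predicate were ignored, the last iteration would index `a[14]` and raise `IndexError`, which is exactly what buffer-level predication prevents.

```python
from typing import List, Optional

N = 14                       # extent of the original loop
VL = 4                       # vector length (split factor)
OUTER = (N + VL - 1) // VL   # 4 outer iterations


def block_level(a: List[float]) -> List[float]:
    """Scalar model of the loop guarded by a block-level predicate."""
    b = [0.0] * N
    for i_0 in range(OUTER):
        for i_1 in range(VL):            # T.vectorized(4) in the TIR
            if i_0 * VL + i_1 < N:       # block-level predicate
                b[i_0 * VL + i_1] = a[i_0 * VL + i_1] + 1.0
    return b


def buffer_level(a: List[float]) -> List[float]:
    """Vector model using a per-lane mask on both the load and the store."""
    b = [0.0] * N
    for i_0 in range(OUTER):
        base = i_0 * VL
        predicate = [base + lane < N for lane in range(VL)]  # active lane mask
        indices = [base + lane for lane in range(VL)]         # T.Ramp(base, 1, VL)
        # Predicated load: disabled lanes never read from a.
        a_load: List[Optional[float]] = [
            a[i] if on else None for i, on in zip(indices, predicate)]
        # Predicated store: disabled lanes never write to b.
        for i, v, on in zip(indices, a_load, predicate):
            if on:
                b[i] = v + 1.0
    return b
```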
TryPredicateBufferAccesses takes a conservative approach for now, focusing only on expressions produced by the split scheduling primitive, but more complex expressions could be supported in the future.
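What "conservative" could mean in practice is sketched below: only a guard of the exact shape `outer * lanes + vec_var < limit`, as produced by a split, is accepted, and anything else is left alone. This is a hypothetical illustration, not TryPredicateBufferAccesses itself: the tiny expression classes (`Var`, `Mul`, `Add`, `LT`) and `match_split_guard` are invented for the example and are not TVM's `tir` node types. On a match it returns the base and limit that an active lane mask would be built from.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Mul:
    a: "Expr"
    b: int


@dataclass(frozen=True)
class Add:
    a: "Expr"
    b: "Expr"


@dataclass(frozen=True)
class LT:
    a: "Expr"
    b: int


Expr = Union[Var, Mul, Add]


def match_split_guard(cond: LT, vec_var: Var,
                      lanes: int) -> Optional[Tuple[Mul, int]]:
    """Return (base, limit) if cond is exactly outer * lanes + vec_var < limit.

    Any other shape (commuted operands, a different factor, ...) is rejected.
    """
    lhs = cond.a
    if not isinstance(lhs, Add) or lhs.b != vec_var:
        return None
    base = lhs.a
    if not isinstance(base, Mul) or base.b != lanes or base.a == vec_var:
        return None
    return base, cond.b
```

For the example above, `match_split_guard(LT(Add(Mul(Var("i_0"), 4), Var("i_1")), 14), Var("i_1"), 4)` yields `(Mul(Var("i_0"), 4), 14)`, i.e. a mask over `i_0 * 4` with limit 14.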
TryPredicateBufferAccesses can be explicitly enabled or disabled with the tir.enable_buffer_level_predication pass context option. It is disabled by default unless the target supports SVE, in which case it is enabled by default.
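The default-resolution rule can be written as a few lines of logic. This is a sketch of the rule as described, not TVM's code: `buffer_level_predication_enabled` is a hypothetical name, `None` stands for "option not set", and `target_has_sve` is assumed to have been determined from the target elsewhere.

```python
from typing import Optional


def buffer_level_predication_enabled(config_value: Optional[bool],
                                     target_has_sve: bool) -> bool:
    """Resolve whether buffer-level predication should be attempted.

    An explicit tir.enable_buffer_level_predication setting always wins;
    otherwise the default follows whether the target supports SVE.
    """
    if config_value is not None:
        return config_value
    return target_has_sve
```

In TVM itself the option would be supplied through the pass context configuration, for example `tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True})`.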
~Note: this commit depends on https://github.com/apache/tvm/pull/16965, so also contains the contents of https://github.com/apache/tvm/pull/16965.~
Co-authored-by: Elen Kalda
Co-authored-by: Neil Hickey
cc @ekalda @tqchen @Lunderberg @cbalint13 @Anndrey24
No problem, and thank you for the revisions!
@tvm-bot rerun
friendly ping @Lunderberg if you have some free time
Apologies for not getting back to the review, and thank you for making the changes!