
[SVE] Add support for representing and creating buffer-level predicates

Open lhutton1 opened this issue 1 year ago • 4 comments

Representation

This commit extends BufferLoad and BufferStore to accept a predicate mask argument indicating which lanes in a vectorized buffer load/store should be read/written.

As a simple example, we can load all lanes:

tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(1, 8))

Or disable all lanes, so that nothing is loaded:

tir.BufferLoad(buf, [tir.Ramp(0, 1, 8)], predicate=tir.Broadcast(0, 8))
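To make the lane semantics concrete, here is a minimal pure-Python model of a predicated vector load and store. It is a sketch, not TVM's implementation: `ramp`, `broadcast`, `predicated_load` and `predicated_store` are hypothetical helpers, buffers are plain lists, and masked-off load lanes are filled with a placeholder only so that the sketch has something to return. What TVM or the hardware yields for those lanes is not modeled here.

```python
def ramp(base, stride, lanes):
    """Indices described by a Ramp(base, stride, lanes) expression."""
    return [base + stride * i for i in range(lanes)]


def broadcast(value, lanes):
    """A Broadcast(value, lanes) expression: the same value in every lane."""
    return [value] * lanes


def predicated_load(buf, indices, predicate, fill=0):
    """Read buf[idx] for active lanes only; inactive lanes get the placeholder `fill`."""
    if len(indices) != len(predicate):
        raise ValueError("predicate must have exactly one entry per lane")
    return [buf[idx] if active else fill for idx, active in zip(indices, predicate)]


def predicated_store(buf, indices, values, predicate):
    """Write values for active lanes only; inactive lanes leave buf untouched."""
    if not (len(indices) == len(values) == len(predicate)):
        raise ValueError("indices, values and predicate must have the same lane count")
    for idx, value, active in zip(indices, values, predicate):
        if active:
            buf[idx] = value


buf = list(range(10, 18))  # an 8-element buffer holding 10..17
all_lanes = predicated_load(buf, ramp(0, 1, 8), broadcast(1, 8))
no_lanes = predicated_load(buf, ramp(0, 1, 8), broadcast(0, 8))
print(all_lanes)  # [10, 11, 12, 13, 14, 15, 16, 17]
print(no_lanes)   # [0, 0, 0, 0, 0, 0, 0, 0]

out = [-1] * 8
predicated_store(out, ramp(0, 1, 8), all_lanes, [i < 3 for i in range(8)])
print(out)  # [10, 11, 12, -1, -1, -1, -1, -1]
```

The key property is that an inactive lane never touches memory: a load does not read its index and a store does not write it.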

In TVMScript, buffer loads and stores are displayed using a shorthand notation, e.g. A[0:4], but there is no clear way to extend this notation to support predicates. Predicated accesses therefore use the vload/vstore notation, e.g. A.vload([T.Ramp(0, 1, 4)], predicate=...), and the TVMScript printer falls back to this notation whenever a predicate is specified.
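The printer's choice between the two notations can be sketched as follows. This is a string-formatting illustration under simplifying assumptions, not the TVMScript printer's code: `print_load` is a hypothetical helper that handles only unit-stride `Ramp` indices with integer bases and receives the predicate as an already-printed string.

```python
def print_load(buf, base, lanes, predicate=None):
    """Choose a notation for a unit-stride vector load (hypothetical helper)."""
    if predicate is None:
        # No predicate: the slice shorthand is enough, e.g. A[0:4].
        return f"{buf}[{base}:{base + lanes}]"
    # A predicate has no place in the slice syntax, so fall back to vload.
    return f"{buf}.vload([T.Ramp({base}, 1, {lanes})], predicate={predicate})"


print(print_load("A", 0, 4))          # A[0:4]
print(print_load("A", 0, 4, "mask"))  # A.vload([T.Ramp(0, 1, 4)], predicate=mask)
```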

Creation

Buffer-level predication is most useful when combined with the tir.get_active_lane_mask intrinsic, which can be used to mask off the excess lanes when the extent of the vectorized axis is not divisible by the vector length. A detailed example and rationale can be found in the RFC.
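A minimal model of the lane mask follows, assuming that lane `i` is active exactly when `base + i < limit`; this matches how the intrinsic replaces the condition `i_0 * 4 + i_1 < 14` in the VectorizeLoop example. Here `get_active_lane_mask` is a plain-Python stand-in that takes a lane count instead of TVM's dtype string; it is not the TVM intrinsic.

```python
def get_active_lane_mask(lanes, base, limit):
    """Model of the mask: lane i is active exactly when base + i < limit."""
    return [base + i < limit for i in range(lanes)]


# 14 elements processed 4 lanes at a time: 16 lanes are visited, 2 must be masked off.
for i_0 in range(4):
    print(i_0, get_active_lane_mask(4, i_0 * 4, 14))
# 0 [True, True, True, True]
# 1 [True, True, True, True]
# 2 [True, True, True, True]
# 3 [True, True, False, False]
```

Only the final vector is partial, so every iteration except the last runs with a full mask.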

Predicated buffer loads/stores are created in the VectorizeLoop pass by TryPredicateBufferAccesses, which aims to convert block-level predicates, e.g.

for i_0 in T.serial(4):
    for i_1 in T.vectorized(4):
        if i_0 * 4 + i_1 < 14:
            B[i_0 * 4 + i_1] = A[i_0 * 4 + i_1] + 1.0

to buffer-level predicates, e.g.

for i_0 in T.serial(4):
    predicate = T.get_active_lane_mask("int1x4", i_0 * 4, 14)
    A_load = T.meta_var(A.vload([T.Ramp(i_0 * 4, 1, 4)], predicate=predicate))
    B.vstore([T.Ramp(i_0 * 4, 1, 4)], A_load + 1.0, predicate=predicate)

TryPredicateBufferAccesses takes a conservative approach for now, handling only the expressions produced by the split scheduling primitive; more complex expressions could be supported in the future.
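The block-level and buffer-level forms of the split example can be checked against each other with a plain-Python sketch. It is not TVM code: `block_level`, `buffer_level` and the list-based `get_active_lane_mask` are hypothetical helpers, and the predicated load/store are modeled as list operations that skip inactive lanes. The mask's base `i_0 * factor` and limit `extent` are read directly off the split-produced condition `i_0 * factor + i_1 < extent`. `A` holds only 14 elements while 16 lanes are iterated, so the tail is exercised, and the sketch adds 1.0 to the loaded lanes before storing, following the block-level loop.

```python
def get_active_lane_mask(lanes, base, limit):
    """Model of the mask: lane i is active exactly when base + i < limit."""
    return [base + i < limit for i in range(lanes)]


def block_level(A, extent, factor, n_outer):
    """Original form: a scalar `if` guards every lane of the vectorized loop."""
    B = [0.0] * (n_outer * factor)
    for i_0 in range(n_outer):
        for i_1 in range(factor):
            if i_0 * factor + i_1 < extent:
                B[i_0 * factor + i_1] = A[i_0 * factor + i_1] + 1.0
    return B


def buffer_level(A, extent, factor, n_outer):
    """Rewritten form: one mask per vector, applied to both the load and the store."""
    B = [0.0] * (n_outer * factor)
    for i_0 in range(n_outer):
        predicate = get_active_lane_mask(factor, i_0 * factor, extent)
        indices = [i_0 * factor + lane for lane in range(factor)]  # Ramp(i_0 * factor, 1, factor)
        # Predicated load: inactive lanes never read A (None marks "not loaded").
        A_load = [A[idx] if active else None for idx, active in zip(indices, predicate)]
        # Predicated store: inactive lanes leave B unchanged.
        for idx, value, active in zip(indices, A_load, predicate):
            if active:
                B[idx] = value + 1.0
    return B


A = [float(x) for x in range(14)]  # 14 valid elements, 4 * 4 = 16 lanes iterated
print(block_level(A, 14, 4, 4) == buffer_level(A, 14, 4, 4))  # True
```

Neither version reads `A[14]` or `A[15]`, which would be out of bounds; the buffer-level form achieves this with a single mask per vector instead of a per-lane branch.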

TryPredicateBufferAccesses can be explicitly enabled or disabled with the tir.enable_buffer_level_predication pass context option. If the option is not set, it is enabled when the target supports SVE and disabled otherwise.
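The resulting precedence can be summarized in a few lines. This sketch assumes that an explicitly set option always overrides the SVE-based default. `buffer_level_predication_enabled` is a hypothetical helper, not part of TVM; in TVM itself the option would be supplied through the pass context, e.g. `tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True})`.

```python
from typing import Optional


def buffer_level_predication_enabled(option: Optional[bool], target_has_sve: bool) -> bool:
    """Effective setting: an explicit option wins; otherwise follow SVE support."""
    if option is not None:
        return option
    return target_has_sve


print(buffer_level_predication_enabled(None, target_has_sve=True))   # True
print(buffer_level_predication_enabled(False, target_has_sve=True))  # False
```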

~Note: this commit depends on https://github.com/apache/tvm/pull/16965, so also contains the contents of https://github.com/apache/tvm/pull/16965.~

Co-authored-by: Elen Kalda
Co-authored-by: Neil Hickey

lhutton1, May 03 '24 09:05

cc @ekalda @tqchen @Lunderberg @cbalint13 @Anndrey24

lhutton1, May 07 '24 11:05

No problem, and thank you for the revisions!

Lunderberg, May 13 '24 16:05

@tvm-bot rerun

lhutton1, May 21 '24 08:05

friendly ping @Lunderberg if you have some free time

lhutton1, May 22 '24 12:05

Apologies for not getting back to the review, and thank you for making the changes!

Lunderberg, May 29 '24 14:05