[TIR] Enhance `LowerThreadAllreduce` pass to automatically infer shared memory scope
The `LowerThreadAllreduce` pass enables efficient block-level reduction. However, block reduction often requires a large amount of shared memory. The current implementation of `LowerThreadAllreduce` only supports static shared memory allocation for the reduction buffer. When the other shared buffers in the kernel are declared with the `shared.dyn` scope, this prevents the reduction buffer from being merged with them. For example:
A_shared: "shared.dyn"
B_shared: "shared.dyn"
C_shared: "shared.dyn"
red: "shared" (cannot be merged into the shared.dyn memory pool)
This pull request addresses the issue by first collecting the kernel's buffer allocations and then determining the memory scope of the reduction buffer from them. The reduction buffer can then be merged into the same pool as the other shared buffers by the subsequent `MergeSharedMemoryAllocations` pass.
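The scope-inference idea can be illustrated with a minimal pure-Python sketch. It does not use TVM's API. `Alloc`, `infer_reduce_scope`, and `group_by_scope` are hypothetical names chosen for illustration. The rule "reuse `shared.dyn` if any existing shared buffer already uses it, otherwise fall back to static `shared`" is an assumption about the intended behavior, not a copy of the actual pass.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Alloc:
    """Hypothetical stand-in for a buffer allocation collected from a kernel."""
    name: str
    scope: str  # e.g. "shared", "shared.dyn", "local"


def infer_reduce_scope(allocs: List[Alloc]) -> str:
    """Pick a scope for the reduction buffer (assumed rule, see lead-in).

    If any collected allocation already lives in "shared.dyn", reuse that
    scope so the reduction buffer can join the same dynamic pool; otherwise
    keep the previous behavior of a static "shared" buffer.
    """
    if any(a.scope == "shared.dyn" for a in allocs):
        return "shared.dyn"
    return "shared"


def group_by_scope(allocs: List[Alloc]) -> Dict[str, List[str]]:
    """Group shared-memory buffers by scope; each group models one mergeable pool."""
    pools: Dict[str, List[str]] = {}
    for a in allocs:
        if a.scope.startswith("shared"):
            pools.setdefault(a.scope, []).append(a.name)
    return pools


if __name__ == "__main__":
    kernel = [
        Alloc("A_shared", "shared.dyn"),
        Alloc("B_shared", "shared.dyn"),
        Alloc("C_shared", "shared.dyn"),
    ]
    red = Alloc("red", infer_reduce_scope(kernel))
    print(group_by_scope(kernel + [red]))
    # {'shared.dyn': ['A_shared', 'B_shared', 'C_shared', 'red']}
```

With a hard-coded `"shared"` scope, `red` would instead form a second, separate group. That is the situation described above, where the reduction buffer stays outside the merged pool.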
Please add a test case for this enhancement.