tvm
[Bugfix] Fix improper touched buffer assignment of Pass MergeSharedMemoryAllocations
As discussed in issue #17375, the current rule for assigning touched buffers is not appropriate. Consider the following example:
code_block_0
for k in range(0, 10):  # (the gen point of A_shared and B_shared will be injected into this for expression)
    for i in range(0, 10):
        A_shared <- A
    for i in range(0, 10):
        B_shared <- B
    code_block_1 (consume A_shared and B_shared)
code_block_2 (produce and consume C_shared)
This setup works by chance in simple GEMM scenarios. However, the correct approach should be:
code_block_0
for k in range(0, 10):
    for i in range(0, 10):
        A_shared <- A  # (the gen point of A_shared should be bound to this BufferStore node)
    for i in range(0, 10):
        B_shared <- B  # (the gen point of B_shared should be bound to this BufferStore node)
    code_block_1 (consume A_shared and B_shared)
code_block_2 (produce and consume C_shared)
This approach works correctly even in more complex scenarios, such as batched GEMM, where the naive template would fail.
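To make the difference between the two rules concrete, here is a minimal toy model in Python. It is not TVM's implementation: `Stmt`, `leaves`, and `gen_points` are hypothetical names introduced only for illustration, and the statement tree is a simplified rendering of the pseudocode above. It only shows which statement each shared buffer's gen point ends up attached to under a coarse rule (the enclosing top-level statement) versus a fine-grained rule (the statement that actually touches the buffer).

```python
from dataclasses import dataclass, field


@dataclass
class Stmt:
    """Toy statement: a name, the buffers it touches, and nested children."""
    name: str
    touches: tuple = ()
    body: list = field(default_factory=list)


def leaves(stmt):
    """Yield the innermost statements under `stmt` in program order.

    Simplification: only statements without children carry touches.
    """
    if not stmt.body:
        yield stmt
    else:
        for child in stmt.body:
            yield from leaves(child)


def gen_points(program, per_store):
    """Map each buffer to the statement where it is first touched.

    per_store=False: attribute the touch to the enclosing top-level statement.
    per_store=True:  attribute the touch to the touching statement itself.
    """
    gen = {}
    for top in program:
        for leaf in leaves(top):
            owner = leaf if per_store else top
            for buf in leaf.touches:
                gen.setdefault(buf, owner.name)  # keep the first touch only
    return gen


# Simplified rendering of the example in this description.
program = [
    Stmt("code_block_0"),
    Stmt("for_k", body=[
        Stmt("for_i_0", body=[Stmt("store_A_shared", touches=("A_shared",))]),
        Stmt("for_i_1", body=[Stmt("store_B_shared", touches=("B_shared",))]),
        Stmt("code_block_1", touches=("A_shared", "B_shared")),
    ]),
    Stmt("code_block_2", touches=("C_shared",)),
]

coarse = gen_points(program, per_store=False)
fine = gen_points(program, per_store=True)
print(coarse)
print(fine)
```

Under the coarse rule both `A_shared` and `B_shared` are generated at `for_k`, so the analysis cannot tell their start points apart. Under the fine-grained rule each buffer's gen point sits at its own store, which is the behavior this change aims for.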
This pull request makes a small modification to the MergeSharedMemoryAllocations pass to enable the correct analysis, and always disables the naive shared memory buffer fusion in the StorageRewrite pass when the kernel uses dynamic shared memory.