
[Bug] ThreadStorageSync Pass must be put after MergeSharedMemory Pass

Open: LeiWang1999 opened this issue

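In our current lowering pipeline, ThreadSync runs before the MergeSharedMemoryAllocations pass. This ordering can lead to undefined behavior. MergeSharedMemoryAllocations merges the shared memory allocations and can place different buffers at the same offsets, which changes the memory region each buffer access touches. As a result, the barriers ThreadSync inserted for the original, separate buffers may no longer be sufficient.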

https://github.com/apache/tvm/blob/main/src/driver/driver_api.cc#L585-L613

```cpp
  bool detect_global_barrier =
      pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value();
  if (detect_global_barrier) {
    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
  }

  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
  mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
  mixed_pass_list.push_back(tir::transform::InferFragment());
  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

  bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

  if (use_async_copy) {
    mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
  }

  bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32", Bool(false)).value();
  if (ptx_ldg32) {
    mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
  }

  mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
  // MergeSharedMemoryAllocations must be applied after SplitHostDevice
  // because the merged allocation site is at the beginning of each device function
  mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
```

Consider a simple matmul schedule with the following shared memory access pattern:

```
Store A_shared
Store B_shared

tvm_storage_sync

Load A_shared
Load B_shared

Store C_shared
tvm_storage_sync
Load C_shared
```

The ThreadSync pass does not inject a tvm_storage_sync before Store C_shared. That makes sense: C_shared is a separate allocation that does not overlap A_shared or B_shared, so storing to C_shared cannot interfere with the preceding loads.

However, once shared memory is merged, C_shared can reuse the memory space of A_shared and B_shared, because those two buffers are no longer live when C_shared is first written:

```
Store A_shared
Store B_shared

tvm_storage_sync

Load A_shared
Load B_shared

Store C_shared  (reuses the memory of A_shared and B_shared)
tvm_storage_sync
Load C_shared   (reuses the memory of A_shared and B_shared)
```

Here a tvm_storage_sync is required before Store C_shared. Without it, some threads may already be writing C_shared while other threads are still executing Load A_shared on the same bytes, so those loads can read overwritten values. This write-after-read race shows up as nondeterministic, small errors in the output.

The fix is simple: run the ThreadStorageSync (ThreadSync) passes after MergeSharedMemoryAllocations, so that barrier insertion sees the final, merged buffer layout.

LeiWang1999, Oct 04 '24 17:10