
Inconsistent parallelization found with series of trivial reductions


🐛 Describe the bug

This code is extracted from a portion of torch._decomp.decompositions.native_batch_norm lowered to nvprims and manually translated to C++. There may be an even more minimal reproducer, but I couldn't come up with one. I initially thought it was a problem with aliasOutputToInput, but then found a smaller example that fails without it.

TEST_F(NVFuserTest, FusionBNMultipleSqueeze_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr.get();
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(4);
  auto tv1 = makeSymbolicTensor(1);
  auto tv2 = TensorViewBuilder()
      .ndims(4)
      .shape({1, -1, 1, 1})
      .dtype(DataType::Float)
      .contiguity(std::vector<bool>(4, true))
      .build();
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);

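  // Copy tv2 and expand its size-1 dimensions to the full {32, C, 112, 112} activation shape.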
  auto tv3 = set(tv2);
  auto tv4 = expand(
      tv3,
      {IrBuilder::create<Int>(32),
       IrBuilder::create<Int>(-1),
       IrBuilder::create<Int>(112),
       IrBuilder::create<Int>(112)});
  auto tv5 = sub(tv0, tv4);
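  // Series of trivial reductions: dims 3, 2, and 0 of tv2 all have extent 1.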
  auto tv6 = sum(tv2, {3}, false);
  auto tv7 = sum(tv6, {2}, false);
  auto tv8 = sum(tv7, {0}, false);
  auto tv9 = add(tv8, tv1);

  fusion.addOutput(tv5);
  fusion.addOutput(tv9);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::rand({32, 32, 112, 112}, options);
  auto t1 = at::rand({32}, options);
  auto t2 = at::rand({1, 32, 1, 1}, options);

  FusionExecutorCache fec(std::move(fusion_ptr));
  auto outputs = fec.runFusionWithInputs({t0, t1, t2});
}
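
For reference, the two fusion outputs correspond roughly to the following ATen computation (my own sketch; the names t4_ref, ref_tv5, ref_tv9 are mine and not part of the test):

auto t4_ref = t2.expand({32, 32, 112, 112});
auto ref_tv5 = t0 - t4_ref;
auto ref_tv9 = t2.sum({3}, false).sum({2}, false).sum({0}, false) + t1;

These could be compared against outputs[0] and outputs[1], e.g. with the testValidate helper used by other NVFuser tests.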

Running the test fails with:

C++ exception with description "producer->getMemoryType() == MemoryType::Global INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/pytorch/master/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp":457, please report a bug to PyTorch. Inconsistent parallelization found between TV7 (T7_l[ rblockIdx.x119{( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) )}, iblockIdx.y122{( ceilDiv(( 1 * T0.size[1] ), 1) )}, iUS123{1}, rS118{4}, rthreadIdx.x120{blockDim.x} ] produce_pos( 3)) and TV8(T8_l[ iblockIdx.y125{( ceilDiv(( 1 * T0.size[1] ), 1) )}, iUS126{1} ] ca_pos( 2 )). Producer is required to be in Global Memory based on parallelization strategy.
Exception raised from build at /home/iyashchuk/dev/pytorch/master/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp:457

Based on the error message, a simple fix seems to be auto tv8 = sum(set(tv7), {0}, false);. Is this error expected, and is this fix the correct thing to do?
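
For concreteness, a sketch of the reduction chain with that workaround applied (only the affected lines of the reproducer; I haven't verified this is the intended fix):

  auto tv6 = sum(tv2, {3}, false);
  auto tv7 = sum(tv6, {2}, false);
  // Copy tv7 with set() so the final reduction over dim 0 consumes the copy.
  auto tv8 = sum(set(tv7), {0}, false);
  auto tv9 = add(tv8, tv1);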

Versions

devel branch

IvanYashchuk (Sep 14 '22 16:09)