cutlass icon indicating copy to clipboard operation
cutlass copied to clipboard

[BUG] Logic issue in nondeterministic reduction mode of Stream-K tile scheduler.

Open allispaul opened this issue 1 month ago • 4 comments

Describe the bug The nondeterministic reduction mode of the Stream-K tile scheduler, as described here, is supposed to have all CTAs collaborating on a tile wait for the first one to store its data (to initialize the workspace), and to have the final CTA wait for all others (so that it can load the data from the workspace and compute the epilogue). This appears not to happen in the fixup code. Instead, all non-final CTAs (the !compute_epilogue branch) besides the initial one wait for the previous CTA, as in the deterministic mode; while the final CTA (the else branch) only waits for the initial CTA. In particular, the final CTA can compute the epilogue before all non-initial, non-final CTAs have performed their reduction, leading to incorrect results. It just seems like some of the branches in fixup got swapped around, so it should be pretty simple to fix.

Steps/Code to reproduce bug Below is a reproducing example based on CUTLASS example 49. I believe the issue should trigger when the scheduler assigns at least 3 worktiles to an SM, which is going to depend in part on the specific device being used; on an H100 PCIe with 114 SMs, it triggered consistently for me on 1024x1024xK GEMMs with K >= 4096.

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"

using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////

struct Options {
  int m, n, k, l;
  float alpha, beta;
  int seed;

  Options():
    m(1024), n(1024), k(4096), l(1),
    alpha(1.f), beta(0.f),
    seed(0)
  { }

  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    cmd.get_cmd_line_argument("m", m, 1024);
    cmd.get_cmd_line_argument("n", n, 1024);
    cmd.get_cmd_line_argument("k", k, 4096);
    cmd.get_cmd_line_argument("l", l, 1);
    cmd.get_cmd_line_argument("seed", seed, 0);
    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
    cmd.get_cmd_line_argument("beta", beta, 0.f);
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
template <
  class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
  class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
  class StageCountType = cutlass::gemm::collective::StageCountAuto,
  class TileSchedulerType = cutlass::gemm::PersistentScheduler,
  bool Deterministic = true
>
struct ExampleRunner {

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;
  using LayoutD = cutlass::layout::ColumnMajor;

  using ElementA = float;
  using ElementB = float;
  using ElementC = float;
  using ElementD = float;
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementScalar = float;

  static constexpr int AlignmentA = 16 / sizeof(ElementA);
  static constexpr int AlignmentB = 16 / sizeof(ElementB);
  static constexpr int AlignmentC = 16 / sizeof(ElementC);
  static constexpr int AlignmentD = 16 / sizeof(ElementD);
  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_64>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementCompute,
      ElementC, LayoutC, AlignmentC,
      ElementD, LayoutD, AlignmentD,
      EpilogueScheduleType,
      DefaultOperation
    >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementA, LayoutA, AlignmentA,
      ElementB, LayoutB, AlignmentB,
      ElementAccumulator,
      Shape<_128,_128,_64>, Shape<_2,_1,_1>,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
  using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
  using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
  using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;

  StrideA stride_A;
  StrideB stride_B;
  StrideC stride_C;
  StrideD stride_D;

  cutlass::DeviceAllocation<ElementA> block_A;
  cutlass::DeviceAllocation<ElementB> block_B;
  cutlass::DeviceAllocation<ElementC> block_C;
  cutlass::DeviceAllocation<ElementD> block_D;
  cutlass::DeviceAllocation<ElementD> block_ref_D;

  bool verify(const ProblemShapeType& problem_size, float alpha, float beta) {
    auto [M, N, K, L] = problem_size;

    cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K}));
    cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N}));
    cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N}));
    cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N}));

    cutlass::reference::device::GemmComplex(
          {M, N, K},
          ElementScalar(alpha),
          ref_A,
          cutlass::ComplexTransform::kNone,
          ref_B,
          cutlass::ComplexTransform::kNone,
          ElementScalar(beta),
          ref_C,
          ref_D,
          ElementAccumulator(0),
          L,     // batch_count
          M * K, // batch_stride_A
          K * N, // batch_stride_B
          M * N, // batch_stride_C
          M * N  // batch_stride_D
        );

    cudaError_t result = cudaDeviceSynchronize();
    if (result != cudaSuccess) {
      std::cerr << "Reference kernel failed. Last CUDA error: "
                << cudaGetErrorString(result) << std::endl;
      return false;
    }

    ElementD* hD =     static_cast<ElementD*>(malloc(M * N * L * sizeof(ElementD)));
    ElementD* hD_ref = static_cast<ElementD*>(malloc(M * N * L * sizeof(ElementD)));
    cudaMemcpy(hD, block_D.get(), M * N * L * sizeof(ElementD), cudaMemcpyDeviceToHost);
    cudaMemcpy(hD_ref, block_ref_D.get(), M * N * L * sizeof(ElementD), cudaMemcpyDeviceToHost);

    float max_diff = 0.0f;
    int max_idx = 0;
    for (int i = 0; i < M * N * L; ++i) {
      float this_diff = abs(static_cast<float>(hD[i]) - static_cast<float>(hD_ref[i]));
      if (this_diff > max_diff) {
        max_diff = this_diff;
        max_idx = i;
      }
    }
    bool passed = true;
    if (max_diff > 0.0f) {
      passed = false;
      std::cerr.precision(4);
      std::cerr << "Max absolute difference: " << max_diff << " at index " << max_idx
                << ", reference = " << hD_ref[max_idx] << ", obtained = " << hD[max_idx] << std::endl;
    }

    free(hD);
    free(hD_ref);
    return passed;
  }

  void initialize(const ProblemShapeType& problem_size, uint64_t seed) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    auto [M, N, K, L] = problem_shape_MNKL;

    stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
    stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
    stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
    stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));

    block_A.reset(M * K * L);
    block_B.reset(K * N * L);
    block_C.reset(M * N * L);
    block_D.reset(M * N * L);
    block_ref_D.reset(M * N * L);

    ElementA* hA;
    ElementB* hB;
    ElementC* hC;

    hA = static_cast<ElementA*>(malloc(M * K * L * sizeof(ElementA)));
    hB = static_cast<ElementB*>(malloc(N * K * L * sizeof(ElementB)));
    hC = static_cast<ElementC*>(malloc(M * N * L * sizeof(ElementC)));

    srand(seed);
    for (int i = 0; i < M * K * L; ++i)
      hA[i] = static_cast<ElementA>(1.0);
      // hA[i] = static_cast<ElementA>(static_cast<double>(rand()) / RAND_MAX - 1);
    for (int i = 0; i < N * K * L; ++i)
      hB[i] = static_cast<ElementB>(1.0);
      // hB[i] = static_cast<ElementB>(static_cast<double>(rand()) / RAND_MAX - 1);
    for (int i = 0; i < M * N * L; ++i)
      hC[i] = static_cast<ElementC>(1.0);
      // hC[i] = static_cast<ElementC>(static_cast<double>(rand()) / RAND_MAX - 1);

    block_A.copy_from_host(hA);
    block_B.copy_from_host(hB);
    block_C.copy_from_host(hC);

    free(hA);
    free(hB);
    free(hC);
  }

  bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
    ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};

    initialize(problem_size, static_cast<uint64_t>(options.seed));

    typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      {block_A.get(), stride_A, block_B.get(), stride_B},
      {{}, // epilogue.thread
       block_C.get(), stride_C, block_D.get(), stride_D},
      hw_info
    };

    arguments.epilogue.thread.alpha = options.alpha;
    arguments.epilogue.thread.beta = options.beta;
    using SchedulerParams = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams;

    if constexpr (Deterministic) {
      arguments.scheduler.reduction_mode = SchedulerParams::ReductionMode::Deterministic;
    } else {
      arguments.scheduler.reduction_mode = SchedulerParams::ReductionMode::Nondeterministic;
    }
    // Force a Stream-K schedule
    arguments.scheduler.decomposition_mode = SchedulerParams::DecompositionMode::StreamK;

    Gemm gemm_op;

    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
      std::cerr << "This kernel is not supported. Last CUDA error is: "
                << cudaGetErrorString(cudaGetLastError()) << std::endl;
      return false;
    }

    gemm_op.initialize(arguments, workspace.get());
    gemm_op.run();
    cudaError_t result = cudaDeviceSynchronize();
    if (result != cudaSuccess) {
      std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
                << cudaGetErrorString(result) << std::endl;
      return false;
    }

    bool passed = verify(problem_size, options.alpha, options.beta);
    if (!passed) {
      std::cerr << "Reference check failed" << std::endl;
    }

    return passed;
  }

};

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

void print_result(const std::string& description, bool passed) {
  std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl;
}

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  Options options;

  options.parse(argc, args);

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  cutlass::KernelHardwareInfo hw_info;

  hw_info.device_id = 0;
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  bool passed;

  ExampleRunner<
    cutlass::gemm::KernelTmaWarpSpecializedCooperative,
    cutlass::epilogue::TmaWarpSpecializedCooperative,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::StreamKScheduler,
    /*Deterministic=*/false> ws_cooperative_stream_k_schedule_auto_stage_runner_nd;
  passed = ws_cooperative_stream_k_schedule_auto_stage_runner_nd.run(options, hw_info);
  print_result("Nondeterministic", passed);

  ExampleRunner<
    cutlass::gemm::KernelTmaWarpSpecializedCooperative,
    cutlass::epilogue::TmaWarpSpecializedCooperative,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::StreamKScheduler,
    /*Deterministic=*/true> ws_cooperative_stream_k_schedule_auto_stage_runner;
  passed = ws_cooperative_stream_k_schedule_auto_stage_runner.run(options, hw_info);
  print_result("Deterministic", passed);

#endif

  return 0;
}

allispaul avatar Jan 07 '25 00:01 allispaul