[BUG] Python `EVT` `Pytorch` Emitter Broken
Describe the bug
The Python interface's PyTorch emitter (cutlass.emit.pytorch) does not output functioning code when compiling a Gemm with an EVT.

Steps/Code to reproduce bug
The script below reproduces the bug. Running it as-is (with jit=True in the call to cutlass.emit.pytorch) triggers the failure; switch jit to False to write out and inspect the generated code (see Additional Context as well).
import torch
import cutlass
from cutlass import Tensor as FakeTensor

print_module = True

m = 8
n = 8
k = 8

type_A = torch.float16
type_B = torch.float16
type_C = torch.float16
type_D = torch.float16

tensor_A = torch.arange(m * k, dtype=type_A, device="cuda").reshape(m, k)
tensor_B = torch.ones(n * k, dtype=type_B, device="cuda").reshape(k, n)
tensor_C = torch.zeros(m * n, dtype=type_C, device="cuda").reshape(m, n)
tensor_D = torch.zeros_like(tensor_C)

plan = cutlass.op.Gemm(
    element=torch.float16,
    layout=cutlass.LayoutType.RowMajor,
    element_accumulator=torch.float32,
)

def epilogue_scale(accum, scale):
    D = scale * accum
    return D

# Construct inputs and outputs
scale = torch.arange(m, dtype=type_C, device="cuda").reshape(m, 1)

examples_tensors = {
    "accum": FakeTensor(
        element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor
    ),
    "scale": scale,
    "D": tensor_D,
}

epilogue_visitor = cutlass.epilogue.trace(epilogue_scale, examples_tensors)
visitor_args = {"scale": scale, "D": tensor_D}
plan.epilogue_visitor = epilogue_visitor

# This works
plan.run(
    tensor_A,
    tensor_B,
    tensor_C,
    tensor_D,
    visitor_args=visitor_args,
    print_module=print_module,
)

binary_op = torch.mul
ref_D = binary_op(tensor_A @ tensor_B, scale)
print(f"ref_D =\n {ref_D}")
print(f"tensor_D =\n {tensor_D}")
print(f"(tensor_D - ref_D).abs().max() = {(tensor_D - ref_D).abs().max()}")

# The call below does not work; set jit=False to write out the generated code,
# which is incorrect (see Additional Context).
op = plan.construct()
mod = cutlass.emit.pytorch(
    op, name="epilogue_broadcast", cc=plan.cc, sourcedir="epilogue", jit=True
)
Expected behavior
Expect the jitted PyTorch module to work the same as the non-PyTorch path (using plan.run, which compiles and runs the kernel directly through the pycuda / C interface).
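For concreteness, a hypothetical sketch of what a working jitted path might look like. The run entry point and its extended signature (taking the extra scale tensor) are assumptions, since no EVT-aware binding is currently emitted at all:

# Hypothetical: how the jitted module from the script above would be expected
# to behave if EVT were supported. `mod`, the input tensors, and `ref_D` refer
# to the repro script; the run(...) entry point and its `scale` parameter are
# assumptions, not the current API.
D = mod.run(tensor_A, tensor_B, tensor_C, scale)
print(f"(D - ref_D).abs().max() = {(D - ref_D).abs().max()}")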
Environment details (please complete the following information):
- GPU: A6000
- nvidia-cutlass: 3.5.0
Additional Context
Below is the generated extension module (with jit set to False). Issues:
- The code refers to DeviceKernel, but none is generated.
- Even though the EVT is declared, none of the interface functions provide arguments for the visitor function.
// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
// helper function allocating the memory
void *device_memory_allocation(size_t size, int device_id = 0)
{
  if (size > 0)
  {
    torch::Device device(torch::kCUDA, device_id);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
    at::Tensor device_tensor = torch::empty({(long)size,}, options);
    return reinterpret_cast<void *>(device_tensor.data_ptr());
  }
  else
  {
    return nullptr;
  }
}
#include "cutlass/gemm/device/gemm_universal.h"
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
  cutlass::gemm::GemmShape<256, 128, 32>,
  cutlass::gemm::GemmShape<64, 64, 32>,
  cutlass::half_t,
  8,
  1 /* epilogue stages */
>;

using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
  OutputTileThreadMap, cutlass::half_t,
  cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
  cutlass::multiplies, cutlass::half_t, float,
  cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
  Compute0,
  Scale,
  Accum>;

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
  OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
  cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
  D,
  EVTCompute0>;
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8
using cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::RowMajor, 8,
    float,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    EVTD,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::arch::OpMultiplyAdd,
    1 /* epilogue stages */
  >::GemmKernel;
// Define named type
struct cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type : public cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base
{
};
using DeviceKernel = cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type;
using ElementCompute = typename DeviceKernel::ElementC;
cutlass::Status epilogue_broadcast_kernel_run(int M, int N, int K,
                                              const DeviceKernel::ElementA *A, const DeviceKernel::ElementB *B, const DeviceKernel::ElementC *C, DeviceKernel::ElementC *D,
                                              ElementCompute alpha, ElementCompute beta)
{
  typename DeviceKernel::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K}, // problem size
      1,
      {alpha, beta},
      A,
      B,
      C,
      D,
      0,
      0,
      0,
      0,                                               // batch strides
      DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
      DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
      DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
      DeviceKernel::LayoutC::packed({M, N}).stride(0)  // ldd
  };

  size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  DeviceKernel gemm_op;
  cutlass::Status status = gemm_op.initialize(arguments,
                                              workspace.get(),
                                              nullptr); // CUDA stream
  if (status != cutlass::Status::kSuccess)
  {
    return status;
  }

  status = gemm_op();
  return status;
}

at::Tensor epilogue_broadcast_kernel(const at::Tensor &A, const at::Tensor &B, at::optional<const at::Tensor> C, float alpha, float beta)
{
  int M = A.size(0);
  int N = B.size(1);
  int K = A.size(1);

  typename DeviceKernel::ElementC *ptrC = (C == at::nullopt) ? nullptr : reinterpret_cast<typename DeviceKernel::ElementC *>(C->contiguous().data_ptr());
  at::Tensor D = B.new_empty({M, N}, torch::kF16);

  cutlass::Status status = epilogue_broadcast_kernel_run(M, N, K,
                                                         reinterpret_cast<typename DeviceKernel::ElementA *>(A.contiguous().data_ptr()),
                                                         reinterpret_cast<typename DeviceKernel::ElementB *>(B.contiguous().data_ptr()),
                                                         ptrC,
                                                         reinterpret_cast<typename DeviceKernel::ElementC *>(D.contiguous().data_ptr()),
                                                         ElementCompute(alpha), ElementCompute(beta));
  TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
  return D;
}
We haven't yet done the plumbing to emit the correct EVT argument structures for creating a PyTorch extension for a kernel that uses EVT. Apologies that this hasn't been better documented and lacks a clear error indicating the lack of support.
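In the meantime, one possible interim workaround (a sketch, not an official API) is to keep driving the EVT kernel through plan.run with visitor_args, which the repro above confirms works, and wrap that call in a small Python helper callable from PyTorch code. The helper name and signature below are illustrative only:

# Interim workaround sketch: reuse the working plan.run(...) path instead of the
# (currently unsupported) EVT PyTorch emitter. `plan` (with plan.epilogue_visitor
# already set) refers to the object built in the repro script above;
# `scaled_gemm` is a hypothetical helper name.
import torch

def scaled_gemm(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                scale: torch.Tensor) -> torch.Tensor:
    # Output buffer for the epilogue's "D" tensor
    D = torch.empty(A.size(0), B.size(1), dtype=C.dtype, device=C.device)
    plan.run(A, B, C, D, visitor_args={"scale": scale, "D": D})
    return D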
@jackkosaian Thanks for the response.

Are there any examples or documentation on how to properly construct arguments for an EVT, other than the streamk example?

Moreover, I'm having trouble with the different epilogue interfaces (#1459) for a relatively simple example. Would appreciate any help!
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.