AITemplate
Do Conv2d Kernels Support Float32 on SM75?
I'm on an SM75 RTX 2070 with CUDA 12.0.0. I extracted a single conv2d layer from the ResNet50 example and want to compile it into an AIT model with dtype float32. However, compilation fails with a StopIteration error, which suggests AIT can't find a kernel for it. The code I executed is:
import numpy as np
from aitemplate.frontend import nn as aitnn
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
from aitemplate.compiler import compile_model


def mark_output(y):
    """Unlike PyTorch, we need to explicitly mark output tensors for optimization.

    Parameters
    ----------
    y : List[Tensor]
        List of output tensors
    """
    if type(y) is not tuple:
        y = (y,)
    for i in range(len(y)):
        y[i]._attrs["is_output"] = True
        y[i]._attrs["name"] = "output_%d" % (i)
        y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]]
        #print("output_{} shape: {}".format(i, y_shape))


class BasicStem(aitnn.Module):
    """
    The standard ResNet stem (layers before the first residual block),
    with a conv, relu and max_pool.
    """
    def __init__(self, in_channels=4, out_channels=8, kernel_size=7, auto_padding=True, dtype='float16', norm="BN", activation="ReLU"):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = 2
        conv_op = aitnn.Conv2dBiasRelu
        self.conv1 = conv_op(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=3 // 2,
            dtype=dtype)
        # self.pool = aitnn.MaxPool2d(3, 2, 1)

    def forward(self, x):
        x = self.conv1(x)
        # x = self.pool(x)
        return x


def compile_module(model_name, batch_size, **kwargs):
    model_name = f"{model_name}_{batch_size}"
    target = detect_target(**kwargs)  # get a CUDA Target
    # Create the input tensor; the shape, dtype and is_input flag must be specified
    x = Tensor(  # AITemplate Tensor (NHWC layout)
        shape=[batch_size, 8, 8, kwargs['in_channels']], dtype=kwargs['dtype'], name="input0", is_input=True
    )
    # Note: kernel_size is not forwarded here, so BasicStem uses its default of 7
    model = BasicStem(
        in_channels=kwargs['in_channels'],
        out_channels=kwargs['out_channels'],
        dtype=kwargs['dtype'])
    # Name all parameters following the PyTorch naming convention
    model.name_parameter_tensor()
    # Forward the input tensor through the model to get the output tensor
    y = model(x)
    # Mark the output tensor
    mark_output(y)
    # Compile the model
    module = compile_model(y, target, "./tmp", model_name)
    return module


# >>> Main Function >>>
if __name__ == "__main__":
    # Test normal kernels
    failed_kernels = []
    dtype = 'float32'
    CI = 16
    print(f"test conv kernel compilation size {CI}")
    model = compile_module(
        model_name="fp32_conv_{}".format(CI),  # name the model after the channel count
        batch_size=8,
        in_channels=CI,
        out_channels=16,
        dtype=dtype,
        kernel_size=4,
        use_fp16_acc=False)
And the error report:
INFO:aitemplate.backend.build_cache_base:Build cache disabled
test conv kernel compilation size 16
2023-07-11 19:19:35,265 INFO <aitemplate.testing.detect_target> Set target to CUDA
75
2023-07-11 19:19:35,270 INFO <aitemplate.backend.target> Loading profile cache from: /home/ioeddk/.aitemplate/cuda.db
2023-07-11 19:19:35,275 INFO <aitemplate.backend.profiler_cache> table_name='cuda_gemm_3' exists in the db
2023-07-11 19:19:35,275 INFO <aitemplate.backend.profiler_cache> table_name='cuda_conv_3' exists in the db
2023-07-11 19:19:35,275 INFO <aitemplate.backend.profiler_cache> table_name='cuda_conv3d_3' exists in the db
Culled cutlass_simt_cf32_cfprop_analytic_cf32_128x128_8x5_nhwc_align1 from manifest
Culled cutlass_simt_cf32_cfprop_optimized_cf32_128x128_8x5_nhwc_align1 from manifest
Culled cutlass_simt_cf32_cdgrad_analytic_cf32_128x128_8x5_nhwc_unity_stride_align1 from manifest
Culled cutlass_simt_cf32_cdgrad_optimized_cf32_128x128_8x5_nhwc_unity_stride_align1 from manifest
Culled cutlass_simt_cf32_cdgrad_analytic_cf32_128x128_8x5_nhwc_align1 from manifest
Culled cutlass_simt_cf32_cdgrad_optimized_cf32_128x128_8x5_nhwc_align1 from manifest
Culled cutlass_simt_cf32_cwgrad_analytic_cf32_128x128_8x5_nhwc_align1 from manifest
Culled cutlass_simt_cf32_cwgrad_optimized_cf32_128x128_8x5_nhwc_align1 from manifest
2023-07-11 19:19:36,131 WARNING <aitemplate.backend.cuda.utils> Arch 75 is not supported by extra ops.
75
75
2023-07-11 19:19:36,138 INFO <aitemplate.compiler.compiler> optimized graph elapsed time: 0:00:00.005867
2023-07-11 19:19:36,138 INFO <aitemplate.compiler.transform.refine_graph> reduced unique ops from 1 to 1
2023-07-11 19:19:36,140 INFO <aitemplate.utils.environ> force_cache=False
Traceback (most recent call last):
  File "/home/ioeddk/GitHub/torchsparse-misc/AITemplate/tests/test_conv2d_fp32.py", line 87, in <module>
    model = compile_module(
  File "/home/ioeddk/GitHub/torchsparse-misc/AITemplate/tests/test_conv2d_fp32.py", line 73, in compile_module
    module = compile_model(y, target, "./tmp", model_name)
  File "/home/ioeddk/venv/full/lib/python3.10/site-packages/aitemplate/compiler/compiler.py", line 274, in compile_model
    compiler.transform.profile(
  File "/home/ioeddk/venv/full/lib/python3.10/site-packages/aitemplate/compiler/transform/profile.py", line 81, in profile
    codegen.gen_profiler(sorted_graph, profiler_dir, dynamic_profiling_strategy)
  File "/home/ioeddk/venv/full/lib/python3.10/site-packages/aitemplate/backend/codegen.py", line 82, in gen_profiler
    func.gen_profiler(workdir, dynamic_profiling_strategy)
  File "/home/ioeddk/venv/full/lib/python3.10/site-packages/aitemplate/compiler/ops/conv/conv2d.py", line 438, in gen_profiler
    if self._should_build_profiler():
  File "/home/ioeddk/venv/full/lib/python3.10/site-packages/aitemplate/compiler/ops/conv/conv2d.py", line 376, in _should_build_profiler
    tmp_key = next(iter(self._attrs["op_instance"].keys()))
StopIteration
However, it works as expected with dtype='float16'.
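The last frame of the traceback calls `next(iter(self._attrs["op_instance"].keys()))`. When no kernel instance survives filtering, that dictionary is empty, and `next()` on an exhausted iterator raises `StopIteration`, which is why the error carries no message. A minimal sketch of the failure and of a guard that fails with a readable message; `first_instance_key` is a hypothetical helper, not part of AITemplate:

```python
def first_instance_key(op_instance: dict, op_name: str = "conv2d") -> str:
    """Return the first kernel-instance key, or raise a descriptive error.

    Hypothetical helper: mirrors the next(iter(...)) call from the traceback,
    but reports why it failed instead of surfacing a bare StopIteration.
    """
    try:
        return next(iter(op_instance.keys()))
    except StopIteration:
        raise RuntimeError(
            f"no kernel instances left for {op_name}; every candidate was "
            "filtered out for this dtype/arch combination"
        ) from None


# Reproduce the bare StopIteration seen in the traceback.
empty: dict = {}
try:
    next(iter(empty.keys()))
except StopIteration:
    print("bare next() on an empty dict raises StopIteration")

# The guarded version names the problem instead.
try:
    first_instance_key(empty)
except RuntimeError as err:
    print(err)
```

Raising a different exception type also keeps the failure from being silently swallowed by any caller that happens to be iterating when it occurs.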
@ioeddk thanks for the question. Indeed, conv2d operators support dtype="float32" only on SM80 and above. The constraint comes from the underlying CUTLASS conv (implicit GEMM) kernels and can be confirmed by this line in the test_conv_bias unit test. I hope this answers your question.
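Given the constraint described in this reply, a caller could reject the unsupported combination before profiling starts, rather than hitting the opaque `StopIteration` inside `compile_model`. A minimal sketch: `check_conv_dtype` and `FP32_CONV_MIN_SM` are hypothetical names, and the SM80 threshold is taken from this reply rather than read from AITemplate. On a real setup the SM number would come from the detected target (the log above prints `75`).

```python
# Minimum compute capability for float32 conv2d, as stated in the reply.
# Assumption: the threshold applies uniformly to the conv2d family.
FP32_CONV_MIN_SM = 80


def check_conv_dtype(sm: int, dtype: str) -> None:
    """Hypothetical pre-flight check that fails before compile_model() runs."""
    if dtype == "float32" and sm < FP32_CONV_MIN_SM:
        raise ValueError(
            f"conv2d with dtype='float32' needs SM{FP32_CONV_MIN_SM}+, "
            f"but the target is SM{sm}; use dtype='float16' instead"
        )


check_conv_dtype(80, "float32")  # passes silently
check_conv_dtype(75, "float16")  # passes silently
try:
    check_conv_dtype(75, "float32")
except ValueError as err:
    print(err)
```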
CUTLASS has SIMT conv kernels for SM75.
@hwu36 thanks for the clarification. I believe the SIMT kernels end up being excluded from profiling. Perhaps @chenyang78 could provide more context here?
If CUTLASS has SIMT conv kernels for SM75, could we enable them in AITemplate by changing the kernel selection code?
We support float for conv2d, but not for conv2d with activations.
https://github.com/facebookincubator/AITemplate/blob/main/python/aitemplate/backend/cuda/conv2d/common_conv2d_bias_add_activation.py#L48
I don't follow conv2d-family kernels closely, so I don't recall why simt was disabled there.
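A flag such as `skip_simt_kernels` presumably works by dropping candidate kernels whose names mark them as SIMT (non-tensor-core) implementations. The sketch below illustrates that idea with a hypothetical `filter_candidates` function; it is not AITemplate's actual code. One sample name is copied from the `cutlass_simt_...` entries in the log above, and the tensor-core name is made up for illustration.

```python
from typing import List


def filter_candidates(names: List[str], skip_simt_kernels: bool = True) -> List[str]:
    """Hypothetical name-based filter: drop SIMT kernels when the flag is set."""
    if not skip_simt_kernels:
        return list(names)
    return [n for n in names if "_simt_" not in n]


candidates = [
    # SIMT kernel name taken from the log
    "cutlass_simt_cf32_cfprop_optimized_cf32_128x128_8x5_nhwc_align1",
    # made-up tensor-core kernel name, for illustration only
    "cutlass_tensorop_f16_s16816fprop_optimized_f16_128x128_32x3_nhwc_align8",
]
print(filter_candidates(candidates))                           # tensor-core name only
print(filter_candidates(candidates, skip_simt_kernels=False))  # both names
```

If SIMT kernels were the only candidates for a float32 conv on SM75, a filter like this would leave the list empty, which is consistent with the empty `op_instance` dict in the traceback.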
After changing it to skip_simt_kernels=False, there is still a StopIteration exception when using conv2d, conv2dBias, and conv2dBiasReLU with float32. It seems other parts of the code are filtering them out?
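When disabling one filter doesn't help, it can be useful to push the candidate list through each filtering stage in turn and report which stage empties it. A hypothetical sketch; the stage labels and predicates are illustrative stand-ins, not AITemplate's real filters:

```python
from typing import Callable, List, Optional, Tuple

# A stage is a (label, keep_predicate) pair.
Stage = Tuple[str, Callable[[str], bool]]


def first_emptying_stage(names: List[str], stages: List[Stage]) -> Optional[str]:
    """Apply each stage in order.

    Returns the label of the first stage that leaves no candidates,
    or None if at least one candidate survives every stage.
    """
    remaining = list(names)
    for label, keep in stages:
        remaining = [n for n in remaining if keep(n)]
        if not remaining:
            return label
    return None


# Hypothetical kernel name; only the "_simt_" marker matters here.
candidates = ["cutlass_simt_sconv2d_fprop_128x128_8x2_nhwc_align1"]
stages: List[Stage] = [
    # skip_simt_kernels=False: this stage keeps everything.
    ("skip_simt_kernels", lambda n: True),
    # An illustrative later stage that still rejects SIMT kernels.
    ("later arch/dtype check", lambda n: "_simt_" not in n),
]
print(first_emptying_stage(candidates, stages))  # later arch/dtype check
```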
Hmm, @hwu36 I think I must be missing something. From the link below:
https://github.com/NVIDIA/cutlass/blob/main/tools/library/scripts/generator.py#L1285-L1295
It seems we don't have SIMT kernels for SM75. Am I looking in the wrong place? Thanks!
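One way both statements in this exchange could hold: if kernels are selected by a compute-capability range rather than by which per-architecture generator function emitted them, SIMT kernels generated for an older architecture could still be eligible on SM75 even though the SM75 section lists none. The sketch below assumes each kernel carries a minimum and maximum compute capability. The `KernelDesc` type, the kernel names, and the numbers are illustrative, not taken from CUTLASS.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class KernelDesc:
    name: str
    min_cc: int  # lowest compute capability the kernel runs on (assumed field)
    max_cc: int  # highest compute capability the kernel runs on (assumed field)


def eligible(kernels: List[KernelDesc], cc: int) -> List[str]:
    """Keep kernels whose [min_cc, max_cc] range contains the target cc."""
    return [k.name for k in kernels if k.min_cc <= cc <= k.max_cc]


kernels = [
    KernelDesc("simt_fp32_conv_from_sm50_list", min_cc=50, max_cc=1024),
    KernelDesc("tensorop_fp32_conv_from_sm80_list", min_cc=80, max_cc=1024),
]
print(eligible(kernels, 75))  # ['simt_fp32_conv_from_sm50_list']
print(eligible(kernels, 80))  # ['simt_fp32_conv_from_sm50_list', 'tensorop_fp32_conv_from_sm80_list']
```

Whether CUTLASS's manifest and AITemplate's kernel selection actually filter this way would need to be checked against their sources.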