ops.conv2d(group=256) outputs NaN and Inf
This is the unit test:
import unittest

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target


class ConvGroupTestCase(unittest.TestCase):
    def test_fp16(self):
        groups = 256  # if changed to 1 this passes
        size = (12, 12)
        target = detect_target()
        X = Tensor(
            shape=[1, *size, 256],
            dtype="float16",
            name="input_0",
            is_input=True,
        )
        W = Tensor(
            shape=[256, 3, 3, 256 // groups],
            dtype="float16",
            name="input_1",
            is_input=True,
        )
        OP = ops.conv2d(stride=1, pad=1, dilate=1, group=groups)
        Y = OP(X, W)
        Y._attrs["name"] = "output_0"
        Y._attrs["is_output"] = True
        module = compile_model(Y, target, "./output", "conv2dgroup")

        X_pt = torch.randn(1, 256, *size).cuda().half()
        W_pt = torch.randn(256, 256 // groups, 3, 3).cuda().half()
        Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1, groups=groups)

        # AITemplate expects NHWC layout; permute the NCHW PyTorch tensors.
        x = X_pt.permute((0, 2, 3, 1)).contiguous()
        w = W_pt.permute((0, 2, 3, 1)).contiguous()
        y = torch.empty([1, *size, 256]).cuda().half()
        module.run_with_tensors({"input_0": x, "input_1": w}, [y])
        y_transpose = y.permute((0, 3, 1, 2))

        self.assertFalse(y_transpose.isnan().any())
        self.assertFalse(y_transpose.isinf().any())
        if target.name() == "cuda":
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1))


if __name__ == "__main__":
    torch.manual_seed(0)
    unittest.main()
The AIT output contains many NaNs and zeros.
I noticed there were no tests covering the group argument of ops.conv2d. Also, PyTorch's conv2d calls this parameter groups, while here it is called group.
I see the generated conv2d_0.cu importing "cutlass/conv/kernel/default_conv2d_fprop.h", but I think it should import and use "default_conv2d_group_fprop.h" instead. So this must be related to codegen.
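As a sanity check on what the kernel should compute (and as a possible interim workaround while group > 1 is broken), a grouped conv2d can be reproduced by running one conv2d per group and concatenating along the channel dimension. This is a minimal plain-PyTorch sketch, not AITemplate code; grouped_conv2d_reference is a hypothetical helper name used only for illustration.

import torch

def grouped_conv2d_reference(x, w, groups, padding=1):
    # x: (N, C_in, H, W), w: (C_out, C_in // groups, kH, kW)
    x_groups = x.chunk(groups, dim=1)   # split input channels into groups
    w_groups = w.chunk(groups, dim=0)   # split output filters into groups
    outs = [
        torch.nn.functional.conv2d(xg, wg, padding=padding)
        for xg, wg in zip(x_groups, w_groups)
    ]
    return torch.cat(outs, dim=1)       # concatenate along output channels

# Same shapes as the failing test: 256 channels, 256 groups.
x = torch.randn(1, 256, 12, 12)
w = torch.randn(256, 1, 3, 3)           # 256 // groups == 1 channel per filter
y_ref = torch.nn.functional.conv2d(x, w, padding=1, groups=256)
y_split = grouped_conv2d_reference(x, w, groups=256)
assert torch.allclose(y_ref, y_split, atol=1e-2, rtol=1e-2)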
Group conv (especially depthwise conv) is not fully supported in this release with the CUDA backend, mainly because depthwise conv is a SIMT workload rather than a TensorCore workload.
v0.2 will bring these features. Before that, we will probably gate these cases.
Sorry for the missing documentation/gate function on this.
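A minimal sketch of what such a gate might look like; check_group_support and the error message are hypothetical and not part of the AITemplate API.

# Hypothetical gate: reject group > 1 on the CUDA backend until grouped /
# depthwise conv2d codegen is supported (names below are illustrative only).
def check_group_support(target_name: str, group: int) -> None:
    if target_name == "cuda" and group > 1:
        raise NotImplementedError(
            "conv2d with group > 1 is not supported by the CUDA backend "
            "in this release; it is planned for v0.2."
        )

# Usage sketch before building the op:
# check_group_support(detect_target().name(), groups)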
Depthwise conv will remain a SIMT kernel in CUTLASS 2.11, but the performance will be much better.
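For context, the failing test above (group equal to the 256 input channels, one single-channel filter per output channel) is exactly the depthwise case being discussed. A small PyTorch illustration, assuming the same shapes as the test:

import torch

C = 256                          # groups == number of input channels
x = torch.randn(1, C, 12, 12)
w = torch.randn(C, 1, 3, 3)      # one single-channel 3x3 filter per channel

# Depthwise conv: each output channel is a 2D convolution of one input
# channel with one small filter; there is no reduction across input
# channels, which is why it maps to a SIMT kernel rather than a TensorCore GEMM.
y = torch.nn.functional.conv2d(x, w, padding=1, groups=C)
print(y.shape)  # torch.Size([1, 256, 12, 12])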
Created a PR, mostly to get feedback. It adds depthwise convolution (available in CUTLASS 2.10 as an analytic kernel). The test passes. Would appreciate any feedback.
I just saw there is a PR with "depthwise_conv3d". Any chance v0.11 also adds depthwise_conv2d?