
ops.conv2d(group=256) outputs NaN and Inf

Open dashesy opened this issue 1 year ago • 3 comments

Here is a unit test that reproduces the issue:

import unittest

import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target


class ConvGroupTestCase(unittest.TestCase):
    def test_fp16(self):
        groups = 256  # if changed to 1, this passes
        size = (12, 12)
        target = detect_target()
        X = Tensor(
            shape=[1, *size, 256],
            dtype="float16",
            name="input_0",
            is_input=True,
        )
        # AIT weight layout: (out_channels, kh, kw, in_channels // groups)
        W = Tensor(
            shape=[256, 3, 3, 256 // groups], dtype="float16", name="input_1", is_input=True
        )
        OP = ops.conv2d(stride=1, pad=1, dilate=1, group=groups)
        Y = OP(X, W)
        Y._attrs["name"] = "output_0"
        Y._attrs["is_output"] = True
        module = compile_model(Y, target, "./output", "conv2dgroup")

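        # PyTorch reference in NCHW; weights are (out, in // groups, kh, kw)
        X_pt = torch.randn(1, 256, *size).cuda().half()
        W_pt = torch.randn(256, 256 // groups, 3, 3).cuda().half()
        Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1, groups=groups)
        # AIT expects NHWC activations and (out, kh, kw, in // groups) weights
        x = X_pt.permute((0, 2, 3, 1)).contiguous()
        w = W_pt.permute((0, 2, 3, 1)).contiguous()
        y = torch.empty([1, *size, 256]).cuda().half()
        module.run_with_tensors({"input_0": x, "input_1": w}, [y])
        # Convert the NHWC output back to NCHW to compare against PyTorch
        y_transpose = y.permute((0, 3, 1, 2))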
        self.assertFalse(y_transpose.isnan().any())
        self.assertFalse(y_transpose.isinf().any())
        if target.name() == "cuda":
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1))

if __name__ == "__main__":
    torch.manual_seed(0)
    unittest.main()

The AIT output contains many NaNs and zeros.

I noticed there are no tests for ops.conv2d that exercise group. Also, PyTorch's conv2d calls this parameter groups, while here it is called group.
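
To make the grouping explicit, a naive reference can help: output channel `k` only reads the input channels of group `k // (K / groups)`. This is a plain-Python sketch (nested lists, no dilation), assuming the NHWC activation layout and the (K, R, S, C/groups) weight layout that the test produces by permuting the PyTorch tensors; `conv2d_nhwc_grouped` is a hypothetical name, not an AITemplate or PyTorch function.

```python
def conv2d_nhwc_grouped(x, w, groups, pad=1, stride=1):
    """Naive grouped 2-D convolution (hypothetical reference, no dilation).

    x: nested list [N][H][W][C]            (NHWC activations)
    w: nested list [K][R][S][C // groups]  (same layout as W in the test)
    returns nested list [N][P][Q][K]
    """
    n_batch, h, wd, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    k_out, r, s, c_per_group = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    if c % groups or k_out % groups or c_per_group != c // groups:
        raise ValueError("channel counts do not match groups")
    k_per_group = k_out // groups
    p = (h + 2 * pad - r) // stride + 1
    q = (wd + 2 * pad - s) // stride + 1
    y = [[[[0.0] * k_out for _ in range(q)] for _ in range(p)] for _ in range(n_batch)]
    for n in range(n_batch):
        for oh in range(p):
            for ow in range(q):
                for k in range(k_out):
                    g = k // k_per_group      # group this output channel belongs to
                    c_base = g * c_per_group  # first input channel of that group
                    acc = 0.0
                    for kh in range(r):
                        ih = oh * stride - pad + kh
                        if not 0 <= ih < h:
                            continue  # zero padding
                        for kw in range(s):
                            iw = ow * stride - pad + kw
                            if not 0 <= iw < wd:
                                continue  # zero padding
                            for ci in range(c_per_group):
                                acc += x[n][ih][iw][c_base + ci] * w[k][kh][kw][ci]
                    y[n][oh][ow][k] = acc
    return y


# Depthwise case (groups == C == K): each output channel sees only its own input channel.
print(conv2d_nhwc_grouped([[[[1.0, 2.0]]]], [[[[3.0]]], [[[4.0]]]], groups=2, pad=0))
# -> [[[[3.0, 8.0]]]]
```

It is far too slow for the 12x12x256 test itself, but on tiny inputs it can be compared against both `torch.nn.functional.conv2d` and the AIT module.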

dashesy • Nov 02 '22 23:11

I see that conv2d_0.cu includes "cutlass/conv/kernel/default_conv2d_fprop.h", but I think it should include and use "default_conv2d_group_fprop.h". So this must be related to codegen.
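
A minimal sketch of the codegen decision suggested here, assuming the choice of kernel header depends only on `group`. `fprop_kernel_header` and `include_line` are hypothetical helpers, not AITemplate's actual codegen; the two header paths are the ones named above.

```python
CUTLASS_CONV_KERNEL_DIR = "cutlass/conv/kernel"


def fprop_kernel_header(group: int) -> str:
    """Hypothetical helper: pick the CUTLASS fprop header for a conv2d op."""
    if group < 1:
        raise ValueError(f"group must be >= 1, got {group}")
    # Dense conv keeps the existing header; grouped conv needs the group variant.
    name = "default_conv2d_fprop.h" if group == 1 else "default_conv2d_group_fprop.h"
    return f"{CUTLASS_CONV_KERNEL_DIR}/{name}"


def include_line(group: int) -> str:
    """Hypothetical helper: the #include line a generated .cu file would emit."""
    return f'#include "{fprop_kernel_header(group)}"'


print(include_line(256))
# -> #include "cutlass/conv/kernel/default_conv2d_group_fprop.h"
```

Switching the include alone would presumably not be enough: the generated kernel instantiation would also have to use the group-fprop template and pass the group count to it.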

dashesy • Nov 03 '22 00:11

Group (especially depthwise) conv is not fully supported in this release with the CUDA backend, mainly because depthwise conv is a SIMT workload rather than a Tensor Core workload.
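
To see why depthwise conv maps poorly to Tensor Cores, it helps to look at the implicit-GEMM shape each group produces. This sketch assumes the standard implicit-GEMM mapping for forward convolution (GEMM M = N*P*Q, GEMM N = K/groups, GEMM K = R*S*C/groups) and, as an illustrative assumption, an fp16 MMA tile with N = 8 and K = 16 (the Ampere `mma.sync` m16n8k16 shape). All names are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GemmShape:
    m: int  # output pixels: N * P * Q
    n: int  # output channels per group
    k: int  # reduction length: R * S * (input channels per group)


def implicit_gemm_per_group(batch, out_h, out_w, out_channels, in_channels, r, s, groups):
    """Shape of the GEMM that each group of a conv2d fprop reduces to."""
    if groups < 1 or in_channels % groups or out_channels % groups:
        raise ValueError("channel counts must be divisible by groups")
    return GemmShape(
        m=batch * out_h * out_w,
        n=out_channels // groups,
        k=r * s * (in_channels // groups),
    )


def round_up(x, multiple):
    return -(-x // multiple) * multiple


def mma_tile_utilization(shape, tile_n=8, tile_k=16):
    # Useful N x K work divided by the N x K work after padding to whole MMA tiles.
    # M is not padded here; in this example 144 = 9 * 16 fills whole tiles anyway.
    return (shape.n * shape.k) / (round_up(shape.n, tile_n) * round_up(shape.k, tile_k))


# Shapes from the unit test: 1x12x12x256 input, 3x3 filter, 256 output channels.
dense = implicit_gemm_per_group(1, 12, 12, 256, 256, 3, 3, groups=1)
depthwise = implicit_gemm_per_group(1, 12, 12, 256, 256, 3, 3, groups=256)
print(dense, mma_tile_utilization(dense))
# -> GemmShape(m=144, n=256, k=2304) 1.0
print(depthwise, mma_tile_utilization(depthwise))
# -> GemmShape(m=144, n=1, k=9) 0.0703125
```

With group=256, the work becomes 256 tiny GEMMs with N = 1 and K = 9, so under this assumed tile shape only about 7% of each padded tile does useful work, which is why a SIMT kernel is the better fit.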

v0.2 will bring these features. Until then, we will probably gate these cases.

Sorry for the missing documentation and gate function on this.
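
One way such a gate could look, as a minimal sketch: `classify_conv2d` and `gate_conv2d` are hypothetical names, not AITemplate functions, and rejecting every `group > 1` case on the CUDA backend is an assumption based on the comment above. Depthwise is taken to mean `group == in_channels`.

```python
def classify_conv2d(group, in_channels, out_channels):
    """Hypothetical: label a conv2d as dense, depthwise, or grouped."""
    if group < 1 or in_channels % group or out_channels % group:
        raise ValueError(
            f"group={group} must divide in_channels={in_channels} "
            f"and out_channels={out_channels}"
        )
    if group == 1:
        return "dense"
    if group == in_channels:
        return "depthwise"
    return "grouped"


def gate_conv2d(group, in_channels, out_channels, backend):
    """Hypothetical gate: fail loudly instead of emitting a dense kernel."""
    kind = classify_conv2d(group, in_channels, out_channels)
    # Only the CUDA backend is gated here, matching the comment above.
    if backend == "cuda" and kind != "dense":
        raise NotImplementedError(
            f"{kind} conv2d (group={group}) is not supported on the CUDA backend yet"
        )
    return kind


print(gate_conv2d(1, 256, 256, "cuda"))
# -> dense
```

Raising at compile time would turn the silent NaN/Inf output reported above into an explicit error.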


-- Bing Xu

antinucleon • Nov 03 '22 00:11

Depthwise conv will remain a SIMT kernel in CUTLASS 2.11, but the performance will be much better.

hwu36 • Nov 03 '22 00:11

I created a PR, mostly to get feedback. It adds depthwise convolution (available in CUTLASS 2.10 as an analytic kernel). The test passes. I would appreciate any feedback.

I just saw there is a PR with "depthwise_conv3d". Any chance v0.11 also adds depthwise_conv2d?

dashesy • Nov 07 '22 23:11