Modify stride positivity check in jt.nn.conv_transpose3d/jt.nn.conv_transpose; add input shape check in jt.nn.conv_transpose3d/jt.nn.conv_transpose

Open fansunqi opened this issue 8 months ago • 0 comments

  1. Modify the stride positivity check in jt.nn.conv_transpose3d/jt.nn.conv_transpose:

Previous:

if stride <= 0:
    raise RuntimeError("non-positive stride is not supported")
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)

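This raises a TypeError when stride is passed as a tuple. The comparison stride <= 0 runs before a scalar stride is expanded into a tuple, and Python cannot compare a tuple with an int.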
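A minimal, Jittor-free reproduction of the failure follows. old_stride_check is a hypothetical stand-in that copies only the lines quoted above; it is not the actual Jittor function.

```python
def old_stride_check(stride):
    # Hypothetical copy of the previous logic: the comparison runs
    # before a scalar stride is expanded into a 3-tuple.
    if stride <= 0:
        raise RuntimeError("non-positive stride is not supported")
    stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
    return stride


# A scalar stride works as intended.
print(old_stride_check(2))  # (2, 2, 2)

# A tuple stride fails on the comparison itself, before any expansion.
try:
    old_stride_check((1, 2, 2))
except TypeError as e:
    print("TypeError:", e)
```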

After modification:

stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0:
    raise RuntimeError("non-positive stride is not supported")

This handles a stride given either as a tuple or as a single number: a scalar is first expanded into a 3-tuple, and each element is then checked.
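The modified logic can be exercised in isolation. normalize_stride is a hypothetical helper that wraps the three lines above so they can be called directly; it is a sketch, not Jittor's code.

```python
def normalize_stride(stride):
    # Hypothetical helper mirroring the modified logic:
    # expand a scalar to a 3-tuple first, then check every element.
    stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
    if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0:
        raise RuntimeError("non-positive stride is not supported")
    return stride


print(normalize_stride(2))          # (2, 2, 2)
print(normalize_stride((1, 2, 2)))  # (1, 2, 2)

# Non-positive values are rejected whether given as a scalar or inside a tuple.
for bad in (0, (1, 0, 2), (-1, 1, 1)):
    try:
        normalize_stride(bad)
    except RuntimeError as e:
        print(bad, "->", e)
```

An equivalent condition is any(s <= 0 for s in stride), which also works for strides of any length. Note that the isinstance(stride, tuple) test still treats a list such as [1, 2, 2] as a scalar, so widening it to (tuple, list) may be worth considering.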

  2. Add an input shape check in jt.nn.conv_transpose3d/jt.nn.conv_transpose.
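The issue does not spell out what the input shape check covers. As one plausible sketch, assuming the input is laid out as (N, C_in, *spatial) and the weight as (C_in, C_out, *kernel), a check on plain shape tuples could look like this. check_conv_transpose_input and its spatial_dims parameter are hypothetical names, not part of Jittor.

```python
def check_conv_transpose_input(x_shape, w_shape, spatial_dims=3):
    # Hypothetical helper. Assumptions (not stated in the issue):
    # input is (N, C_in, *spatial) and weight is (C_in, C_out, *kernel).
    expected = 2 + spatial_dims
    if len(x_shape) != expected:
        raise RuntimeError(
            f"expected a {expected}-D input, got shape {tuple(x_shape)}")
    if len(w_shape) != expected:
        raise RuntimeError(
            f"expected a {expected}-D weight, got shape {tuple(w_shape)}")
    # Channel dimension of the input must match the weight's first dimension.
    if x_shape[1] != w_shape[0]:
        raise RuntimeError(
            f"input has {x_shape[1]} channels, weight expects {w_shape[0]}")


# A well-formed 5-D input passes silently.
check_conv_transpose_input((2, 4, 8, 8, 8), (4, 6, 3, 3, 3))

# A 4-D input given to the 3-D case is rejected with a readable message.
try:
    check_conv_transpose_input((4, 8, 8, 8), (4, 6, 3, 3, 3))
except RuntimeError as e:
    print(e)  # expected a 5-D input, got shape (4, 8, 8, 8)
```

Raising early with the offending shape in the message is meant to replace a less informative failure deeper in the computation; the exact conditions the real check should enforce are left to the implementation.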

fansunqi, Jun 10 '24 08:06