jittor
Fix the stride positivity check in jt.nn.conv_transpose3d/jt.nn.conv_transpose; add an input shape check to jt.nn.conv_transpose3d/jt.nn.conv_transpose
- Fix the stride positivity check in jt.nn.conv_transpose3d/jt.nn.conv_transpose.
Before:
```python
if stride <= 0:
    raise RuntimeError("non-positive stride is not supported")
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
```
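Because the comparison runs before `stride` is normalized, it raises an error whenever `stride` is passed as a tuple (Python cannot compare a tuple with an int, so `stride <= 0` raises a `TypeError`).
After: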
```python
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0:
    raise RuntimeError("non-positive stride is not supported")
```
The check now runs after normalization, so it works whether `stride` is a single number or a tuple.
- Add an input shape check to jt.nn.conv_transpose3d/jt.nn.conv_transpose.
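A minimal, framework-free sketch of both checks, written against plain shape tuples so it runs without Jittor installed. The helper names `normalize_stride`, `old_stride_check`, and `check_input_shape` are hypothetical. The stride logic copies the before/after snippets above; the PR does not spell out which shape conditions it enforces, so the 5-D rank check and the channel-match check are assumptions based on the usual `(N, C, D, H, W)` input and `(in_channels, out_channels, kD, kH, kW)` weight layout for a 3-D transposed convolution.

```python
# Hypothetical sketch mirroring the checks described in this PR.
# Operates on shape tuples only; no Jittor dependency.

def old_stride_check(stride):
    """Previous ordering: compare before normalizing (fails for tuples)."""
    if stride <= 0:  # TypeError when stride is a tuple
        raise RuntimeError("non-positive stride is not supported")
    return stride if isinstance(stride, tuple) else (stride, stride, stride)


def normalize_stride(stride):
    """Fixed ordering: expand an int to a 3-tuple, then reject non-positive entries."""
    stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
    if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0:
        raise RuntimeError("non-positive stride is not supported")
    return stride


def check_input_shape(x_shape, w_shape):
    """Assumed shape checks: 5-D input whose channel dim matches weight dim 0."""
    if len(x_shape) != 5:
        raise RuntimeError(
            f"expected 5-D input (N, C, D, H, W), got {len(x_shape)}-D")
    if x_shape[1] != w_shape[0]:
        raise RuntimeError(
            f"input has {x_shape[1]} channels but weight expects {w_shape[0]}")


print(normalize_stride(2))          # (2, 2, 2)
print(normalize_stride((1, 2, 3)))  # (1, 2, 3)
check_input_shape((1, 4, 8, 8, 8), (4, 2, 3, 3, 3))  # passes silently
```

Validating after normalization means a single comparison path covers both the int and tuple forms, and raising `RuntimeError` for shape problems keeps the error type consistent with the existing stride message.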