
🐛 [Bug] non-contiguous tensor with GroupNorm's affine=False causes shape change error

Open • Jason3900 opened this issue 1 year ago • 1 comment

Bug Description

Passing a non-contiguous tensor to `nn.GroupNorm` with `affine=False` fails during Torch-TensorRT (dynamo) conversion with the following error:

```
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node norm_fn/native_group_norm (kind: aten.native_group_norm.default, args: ('clone <Node>', 'None <NoneType>', 'None <NoneType>', '5 <int>', '512 <int>', '256 <int>', '32 <int>', '1e-05 <float>'))
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: Shape Error (reshape changes volume. Reshaping [1] to [1,512,1,1].)
ERROR:torch_tensorrt [TensorRT Conversion Context]:ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [SHUFFLE]-[aten_ops.native_group_norm.default]-[norm_fn/native_group_norm_reshape_gamma].)
```
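The SHUFFLE error is a volume mismatch: a reshape must preserve the element count, and the log shows a gamma tensor with 1 element being reshaped to [1, 512, 1, 1], which holds 512 elements. With `affine=False`, `weight` and `bias` are `None` (the two `None` args in the DEBUG line), so the converter appears to substitute a single-element gamma instead of one value per channel. This reading is inferred from the log, not from the converter source. Mathematically, GroupNorm without affine parameters equals GroupNorm with gamma = 1 and beta = 0 for every channel. The pure-Python reference below checks this numerically; `group_norm` is a hypothetical helper written for illustration, not a PyTorch or Torch-TensorRT API.

```python
import math
from typing import List, Optional

# Assumed layout: x[n][c] is a flat list of the H*W spatial values for
# sample n and channel c, i.e. shape (N, C, H*W).
Tensor3D = List[List[List[float]]]


def group_norm(
    x: Tensor3D,
    num_groups: int,
    eps: float = 1e-5,
    weight: Optional[List[float]] = None,
    bias: Optional[List[float]] = None,
) -> Tensor3D:
    """Hypothetical reference GroupNorm; weight/bias=None mimics affine=False."""
    num_channels = len(x[0])
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    out: Tensor3D = []
    for sample in x:
        normed: List[List[float]] = []
        for g in range(num_groups):
            channels = sample[g * per_group:(g + 1) * per_group]
            values = [v for ch in channels for v in ch]
            mean = sum(values) / len(values)
            # Biased variance (divide by the element count), as GroupNorm uses.
            var = sum((v - mean) ** 2 for v in values) / len(values)
            inv_std = 1.0 / math.sqrt(var + eps)
            for offset, ch in enumerate(channels):
                c = g * per_group + offset
                # Per-channel affine transform; identity when no parameters are given.
                gamma = 1.0 if weight is None else weight[c]
                beta = 0.0 if bias is None else bias[c]
                normed.append([(v - mean) * inv_std * gamma + beta for v in ch])
        out.append(normed)
    return out


# Tiny example: N=1, C=4, H*W=3, two groups of two channels each.
x = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [0.5, -1.0, 2.0], [3.0, 3.0, 3.0]]]
no_affine = group_norm(x, num_groups=2)
identity_affine = group_norm(x, num_groups=2, weight=[1.0] * 4, bias=[0.0] * 4)
print(no_affine == identity_affine)  # True: gamma=1, beta=0 changes nothing
```

Because `nn.GroupNorm(..., affine=True)` initializes its weight to ones and its bias to zeros, constructing the module with `affine=True` computes the same output as `affine=False` and gives the converter a per-channel gamma. Whether that actually avoids this conversion error is an untested assumption, so treat it only as a possible temporary workaround.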

To Reproduce

Steps to reproduce the behavior:

```python
import torch
import torch.nn as nn
import torch_tensorrt
from einops import rearrange


class GroupNormSpatial(nn.Module):
    """GroupNorm with spatial dimensions ignored."""

    def __init__(self, num_groups, num_channels, epsilon: float = 1e-5, affine=True):
        super().__init__()
        self.norm_fn = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=epsilon, affine=affine)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        b, _, t, _, _ = inputs.shape  # needed to undo the flattening below
        # This produces a non-contiguous tensor; adding .contiguous() doesn't help when compiling with torch_tensorrt
        inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
        out = self.norm_fn(inputs)  # this raises the error above
        out = rearrange(out, "(b t) c h w -> b c t h w", b=b, t=t)
        return out


# Torch-TensorRT compiles a model in eval mode on a CUDA device
model = GroupNormSpatial(num_groups=32, num_channels=512, affine=False).eval().cuda()

trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 512, 5, 16, 16))],
    debug=True,
    ir="dynamo",
    enabled_precisions={torch.float32},
    make_refitable=True,
)
```
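The comment on the first `rearrange` line can be checked with stride arithmetic. A minimal sketch, assuming einops implements `"b c t h w -> (b t) c h w"` as a permute of the c and t axes followed by a reshape, and that this reshape returns a view (no copy) when b == 1, as with the input shape above. The helpers `contiguous_strides`, `permute` and `is_contiguous` are hypothetical, written only for illustration; `is_contiguous` follows PyTorch's rule of ignoring size-1 dimensions (the zero-element special case is left out).

```python
from typing import List, Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> List[int]:
    """Row-major strides (in elements) of a freshly allocated tensor."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def permute(shape: Sequence[int], strides: Sequence[int],
            order: Sequence[int]) -> Tuple[List[int], List[int]]:
    """A permute only reorders (size, stride) pairs; no data moves."""
    return [shape[i] for i in order], [strides[i] for i in order]


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """Each non-size-1 dim must have the stride a fresh allocation would have."""
    expected = contiguous_strides(shape)
    return all(s == e for n, s, e in zip(shape, strides, expected) if n != 1)


# Input from the repro: b c t h w = (1, 512, 5, 16, 16)
shape = [1, 512, 5, 16, 16]
strides = contiguous_strides(shape)  # [655360, 1280, 256, 16, 1]
# "b c t h w -> b t c h w": swap the c and t axes
p_shape, p_strides = permute(shape, strides, [0, 2, 1, 3, 4])
# With b == 1, merging (b t) is a pure view: drop the size-1 b axis.
merged_shape, merged_strides = p_shape[1:], p_strides[1:]
print(merged_shape, merged_strides)  # [5, 512, 16, 16] [256, 1280, 16, 1]
print(is_contiguous(merged_shape, merged_strides))  # False
```

If b > 1 (with c > 1), the (b t) merge cannot be expressed as a view, so `torch.reshape` copies and the result should be contiguous; the non-contiguous case is specific to inputs with b == 1 like the one in this repro.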

Environment

I use the NGC PyTorch container `nvcr.io/nvidia/pytorch:24.10-py3` with torch_tensorrt 2.5.0a0.

Jason3900 • Dec 15 '24 15:12

Taking a look

apbose • Jan 15 '25 01:01