🐛 [Bug] acc tracer doesn't handle torch.max(tensor).values correctly

Open henrylhtsang opened this issue 4 months ago • 2 comments

Bug Description

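The acc tracer fails on a module that reads .values from the result of torch.max(tensor, dim=...): tracing raises an AssertionError instead of producing a graph.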

To Reproduce

import torch
import torch.nn as nn
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer

device = torch.device("cuda")

class MyModule(nn.Module):
    def forward(self, x):
        # torch.max with dim returns a (values, indices) named tuple;
        # only the values tensor is kept
        a = torch.max(torch.abs(x), dim=1).values
        return a

# create an instance of the module
module = MyModule().to(device)

input_data = torch.randn(10, 10, device=device)
acc_tracer.trace(module, [input_data])

Error:

AssertionError:Expected torch.Tensor type for <class 'torch.return_types.max'>
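
For context, the type named in the assertion is what torch.max returns when dim is given: a named tuple with values and indices fields rather than a single tensor. The snippet below is a minimal CPU-only sketch of that eager-mode behavior. It does not involve the acc tracer, and it makes no claim about where in the tracer the assertion is raised.

```python
import torch

x = torch.randn(10, 10)

# With dim given, torch.max returns a named tuple, not a Tensor.
out = torch.max(torch.abs(x), dim=1)
assert isinstance(out, tuple)
assert not isinstance(out, torch.Tensor)

# The two fields are the per-row maxima and their column positions.
values, indices = out
assert torch.equal(out.values, values)
assert torch.equal(out.indices, indices)
assert values.shape == (10,)
```

The .values access in the repro therefore happens on this tuple-like object, which is consistent with the type reported in the error message.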

With the fx tracer instead of the acc tracer, the same module traces without error.
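
A sketch of the fx comparison, assuming that "fx tracer" here means torch.fx.symbolic_trace (the issue does not say which entry point was used). It runs on CPU so it can be checked without a GPU or torch_tensorrt installed.

```python
import torch
import torch.fx
import torch.nn as nn


class MyModule(nn.Module):
    def forward(self, x):
        a = torch.max(torch.abs(x), dim=1).values
        return a


module = MyModule()

# Assumption: plain torch.fx symbolic tracing stands in for "fx tracer".
traced = torch.fx.symbolic_trace(module)

# The traced graph should reproduce the eager result.
input_data = torch.randn(10, 10)
assert torch.equal(traced(input_data), module(input_data))

# Inspect how .values was recorded in the graph.
print(traced.graph)
```

Printing the graph shows how the attribute access on the torch.max result was recorded, which may help when comparing against what the acc tracer expects.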

Expected behavior

acc_tracer.trace should trace the module without raising, as the fx tracer does.

Environment

trunk

henrylhtsang, Sep 30 '24 18:09