TensorRT
🐛 [Bug] acc tracer doesn't handle torch.max(tensor).values correctly
Bug Description
acc_tracer fails to trace a module that calls torch.max with a dim argument and then reads the .values field of the result.
To Reproduce
```python
import torch
import torch.nn as nn
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer

device = torch.device("cuda")


class MyModule(nn.Module):
    def forward(self, x):
        a = torch.max(torch.abs(x), dim=1).values
        return a


# create an instance of the module
module = MyModule().to(device)
input_data = torch.randn(10, 10, device=device)
acc_tracer.trace(module, [input_data])
```
Error:
```
AssertionError: Expected torch.Tensor type for <class 'torch.return_types.max'>
```
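For context: when called with a `dim` argument, `torch.max` returns a `torch.return_types.max` namedtuple with fields `values` and `indices`, not a Tensor. That is the type named in the assertion. The following is a quick CPU-only check that does not depend on acc_tracer:

```python
import torch

x = torch.randn(10, 10)

# With a `dim` argument, torch.max returns a namedtuple, not a Tensor.
result = torch.max(torch.abs(x), dim=1)
print(type(result))                      # <class 'torch.return_types.max'>
print(isinstance(result, torch.Tensor))  # False

# The `.values` and `.indices` fields hold the actual tensors.
print(isinstance(result.values, torch.Tensor))    # True
print(result.values.shape, result.indices.shape)  # torch.Size([10]) torch.Size([10])
```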
If I switch to the FX tracer, the same module traces without error.
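The sketch below compares the two tracers. It assumes "the FX tracer" here means `torch.fx.symbolic_trace`. Plain FX records the `.values` access as a `getattr` node on the namedtuple that `torch.max` produces, so it never needs that intermediate to be a Tensor. It runs on CPU:

```python
import torch
import torch.fx
import torch.nn as nn


class MyModule(nn.Module):
    def forward(self, x):
        return torch.max(torch.abs(x), dim=1).values


# Plain FX symbolic tracing records the namedtuple field access as a getattr node.
gm = torch.fx.symbolic_trace(MyModule())
print(gm.graph)

# The traced module matches eager execution.
x = torch.randn(10, 10)
assert torch.equal(gm(x), MyModule()(x))
```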
Expected behavior
acc_tracer.trace should trace the module without error, as the FX tracer does.
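One possible interim workaround is to use `torch.amax`, which computes the same maxima but returns a plain Tensor with no indices. This is untested against acc_tracer, and it assumes that acc_tracer has a mapping for `torch.amax`. The check below only confirms that the forward results match:

```python
import torch

x = torch.randn(10, 10)

# torch.amax reduces over `dim` and returns a plain Tensor (no indices),
# so there is no namedtuple for a tracer to unpack.
via_max = torch.max(torch.abs(x), dim=1).values
via_amax = torch.amax(torch.abs(x), dim=1)

assert isinstance(via_amax, torch.Tensor)
assert torch.equal(via_max, via_amax)
```

The two differ in backward behavior. `amax` distributes the gradient evenly across tied maxima, whereas `max` with `dim` sends the gradient to a single index.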
Environment
trunk