torch2trt
Writing custom log_softmax converter
Hi, I want to convert `torch.nn.functional.log_softmax` to TensorRT; however, some of my tests fail and I don't know why. This is what I came up with:
```python
import torch
import tensorrt as trt

from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter("torch.Tensor.log_softmax")
@tensorrt_converter("torch.nn.functional.log_softmax")
def convert_log_softmax(ctx):
    input = ctx.method_args[0]
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return

    # get dim from args or kwargs
    if "dim" in ctx.method_kwargs:
        dim = ctx.method_kwargs["dim"]
    elif len(ctx.method_args) >= 2:
        dim = ctx.method_args[1]

    # convert negative dims
    if dim < 0:
        dim = len(input.shape) + dim

    axes = torch_dim_to_trt_axes(dim)

    # softmax followed by an elementwise log
    layer = ctx.network.add_softmax(input=input_trt)
    layer.axes = axes
    layer = ctx.network.add_unary(input=layer.get_output(0), op=trt.UnaryOperation.LOG)

    output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module():
    return torch.nn.LogSoftmax(1)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_dim2():
    return torch.nn.LogSoftmax(2)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_neg1():
    return torch.nn.LogSoftmax(-1)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_dim_neg2():
    return torch.nn.LogSoftmax(-2)
```
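For what it's worth, the identity the converter relies on (log-softmax = elementwise log applied to softmax) does hold in plain PyTorch for moderately scaled inputs. A quick sanity-check sketch, independent of TensorRT and not part of the converter itself:

```python
import torch
import torch.nn.functional as F

# Sanity check: log(softmax(x)) should agree with F.log_softmax(x)
# up to float32 rounding for well-scaled inputs, over the same dims
# exercised by the tests below.
x = torch.randn(1, 3, 3, 3)
for dim in (1, 2, -1, -2):
    composed = torch.log(F.softmax(x, dim=dim))
    reference = F.log_softmax(x, dim=dim)
    assert torch.allclose(composed, reference, atol=1e-6)
```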
This is not very original: I took the existing softmax converter implementation and added the line `layer = ctx.network.add_unary(input=layer.get_output(0), op=trt.UnaryOperation.LOG)`, which is exactly what log-softmax is supposed to compute. However, the tests fail:
```
| torch2trt.converters.log_softmax.test_log_softmax_module | float32 | [(1, 3, 3, 3)] | {} | 2.38E-07 | 155.45 | 3.77E-15 | 8.48e+04 | 1.19e+04 | 0.0893 | 0.178 |
| torch2trt.converters.log_softmax.test_log_softmax_module | float32 | [(1, 3)] | {} | 1.99E+00 | 2.97 | 2.01E+00 | 8.02e+04 | 1.25e+04 | 0.0917 | 0.192 |
| torch2trt.converters.log_softmax.test_log_softmax_module_dim2 | float32 | [(1, 3, 3, 3)] | {} | 1.86E+00 | 11.72 | 8.45E-01 | 7.96e+04 | 1.16e+04 | 0.0918 | 0.188 |
| torch2trt.converters.log_softmax.test_log_softmax_module_neg1 | float32 | [(1, 3, 3, 3)] | {} | 2.13E+00 | 11.94 | 8.13E-01 | 7.93e+04 | 1.18e+04 | 0.0987 | 0.184 |
| torch2trt.converters.log_softmax.test_log_softmax_module_neg1 | float32 | [(1, 3)] | {} | 3.49E+00 | 2.71 | 6.53E+00 | 8.4e+04 | 1.23e+04 | 0.0999 | 0.188 |
| torch2trt.converters.log_softmax.test_log_softmax_module_dim_neg2 | float32 | [(1, 3, 3, 3)] | {} | 1.45E+00 | 11.10 | 6.35E-01 | 7.47e+04 | 1.23e+04 | 0.111 | 0.191 |

NUM_TESTS: 6
NUM_SUCCESSFUL_CONVERSION: 6
NUM_FAILED_CONVERSION: 0
NUM_ABOVE_TOLERANCE: 5
NUM_pSNR_TOLERANCE: 0
```
Interestingly enough, one test passes and the other five fail. What worries me is that the errors are quite large. At first I thought a more numerically stable formula, such as

`logsoftmax(x) = (x - x_max) - log(sum(exp(x - x_max)))`

might perform better here. I tried implementing it, but to no avail: still one test passes and the rest fail. But even with the current implementation I wouldn't expect such extreme errors, which tells me that something is wrong on a more fundamental level. I would appreciate any help @jaybdub
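For context on why the shifted formula is usually preferred: composing `log` with `softmax` can underflow to `log(0) = -inf` when the gap between logits is large, while the max-shifted form stays finite. A minimal sketch in plain PyTorch, purely illustrative; I have not confirmed this is what's happening in the failing tests:

```python
import torch
import torch.nn.functional as F

def stable_log_softmax(x, dim=-1):
    # shift by the max so exp() never overflows and the result stays finite
    shifted = x - x.amax(dim=dim, keepdim=True)
    return shifted - shifted.exp().sum(dim=dim, keepdim=True).log()

x = torch.tensor([0.0, 200.0])
# exp(-200) underflows float32, so the naive composition hits log(0)
print(torch.log(torch.softmax(x, dim=-1)))  # tensor([-inf, 0.])
print(stable_log_softmax(x))                # tensor([-200., 0.])
print(F.log_softmax(x, dim=-1))             # finite as well, matches the stable form
```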
I will gladly create a PR if you help me get it working.