
Writing custom log_softmax converter

Open gabe-scorebreak opened this issue 1 year ago • 0 comments

Hi, I want to convert torch.nn.functional.log_softmax to TensorRT, but some of my tests fail and I don't know why.

This is what I came up with:

from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter("torch.Tensor.log_softmax")
@tensorrt_converter("torch.nn.functional.log_softmax")
def convert_log_softmax(ctx):
    input = ctx.method_args[0]
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return

    # get dims from args or kwargs
    if "dim" in ctx.method_kwargs:
        dim = ctx.method_kwargs["dim"]
    elif len(ctx.method_args) >= 2:
        dim = ctx.method_args[1]

    # convert negative dims
    if dim < 0:
        dim = len(input.shape) + dim

    axes = torch_dim_to_trt_axes(dim)

    # softmax over the selected axes, then take the elementwise log
    layer = ctx.network.add_softmax(input=input_trt)
    layer.axes = axes
    layer = ctx.network.add_unary(input=layer.get_output(0), op=trt.UnaryOperation.LOG)

    output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module():
    return torch.nn.LogSoftmax(1)

@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_dim2():
    return torch.nn.LogSoftmax(2)

@add_module_test(torch.float32, torch.device("cuda"), [(1, 3)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_neg1():
    return torch.nn.LogSoftmax(-1)

@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 3, 3)])
def test_log_softmax_module_dim_neg2():
    return torch.nn.LogSoftmax(-2)

This is not very original: I just took the implementation of the softmax converter and added the line layer = ctx.network.add_unary(input=layer.get_output(0), op=trt.UnaryOperation.LOG), since log-softmax is simply the log of the softmax (a quick eager-mode sanity check of that identity is sketched below).
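For reference, the identity itself can be checked in plain PyTorch (this is just an illustration, nothing TensorRT-specific; the shape mirrors one of the test cases):

import torch
import torch.nn.functional as F

# log_softmax(x) should match log(softmax(x)) up to float32 rounding (~1e-7).
x = torch.randn(1, 3, 3, 3)
reference = F.log_softmax(x, dim=1)
composed = torch.log(F.softmax(x, dim=1))
print((reference - composed).abs().max())

Even so, the converted tests fail: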

|               torch2trt.converters.log_softmax.test_log_softmax_module | float32 |            [(1, 3, 3, 3)] | {} | 2.38E-07 | 155.45 | 3.77E-15 | 8.48e+04 | 1.19e+04 | 0.0893 | 0.178 |
|               torch2trt.converters.log_softmax.test_log_softmax_module | float32 |                  [(1, 3)] | {} | 1.99E+00 | 2.97 | 2.01E+00 | 8.02e+04 | 1.25e+04 | 0.0917 | 0.192 |
|          torch2trt.converters.log_softmax.test_log_softmax_module_dim2 | float32 |            [(1, 3, 3, 3)] | {} | 1.86E+00 | 11.72 | 8.45E-01 | 7.96e+04 | 1.16e+04 | 0.0918 | 0.188 |
|          torch2trt.converters.log_softmax.test_log_softmax_module_neg1 | float32 |            [(1, 3, 3, 3)] | {} | 2.13E+00 | 11.94 | 8.13E-01 | 7.93e+04 | 1.18e+04 | 0.0987 | 0.184 |
|          torch2trt.converters.log_softmax.test_log_softmax_module_neg1 | float32 |                  [(1, 3)] | {} | 3.49E+00 | 2.71 | 6.53E+00 | 8.4e+04 | 1.23e+04 | 0.0999 | 0.188 |
|      torch2trt.converters.log_softmax.test_log_softmax_module_dim_neg2 | float32 |            [(1, 3, 3, 3)] | {} | 1.45E+00 | 11.10 | 6.35E-01 | 7.47e+04 | 1.23e+04 | 0.111 | 0.191 |
NUM_TESTS: 6
NUM_SUCCESSFUL_CONVERSION: 6
NUM_FAILED_CONVERSION: 0
NUM_ABOVE_TOLERANCE: 5
NUM_pSNR_TOLERANCE: 0

Interestingly enough, one test passes and the other five fail. What worries me is that the errors are quite large. At first I thought a more numerically stable formula like logsoftmax(x) = (x - x_max) - log(sum(exp(x - x_max))) would perform better here, so I tried implementing it, but to no avail: still one test passes and the rest fail (a rough TensorRT-layer sketch of that formula is included below for reference). But even with the current implementation I wouldn't expect such extreme errors, which tells me something is wrong on a fundamental level. I would appreciate any help @jaybdub, and I will gladly create a PR if you help me get it working.
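For reference, here is roughly how that max-subtracted formula could be expressed with TensorRT layers inside a torch2trt converter. This is only a sketch of the idea, not a tested fix: the dim handling is copied from the softmax converter above, and the add_reduce/add_elementwise/add_unary layer choices are my assumption of one straightforward mapping.

from torch2trt.torch2trt import *


@tensorrt_converter("torch.nn.functional.log_softmax")
def convert_log_softmax_stable(ctx):
    # logsoftmax(x) = (x - x_max) - log(sum(exp(x - x_max))), reduced over `dim`
    input = ctx.method_args[0]
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
    output = ctx.method_return

    # dim handling copied from the softmax converter
    if "dim" in ctx.method_kwargs:
        dim = ctx.method_kwargs["dim"]
    elif len(ctx.method_args) >= 2:
        dim = ctx.method_args[1]
    if dim < 0:
        dim = len(input.shape) + dim
    axes = torch_dim_to_trt_axes(dim)

    # x_max over the reduction axes (keep_dims so the result broadcasts against x)
    max_layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.MAX, axes, keep_dims=True)
    # shifted = x - x_max
    shifted = ctx.network.add_elementwise(input_trt, max_layer.get_output(0), trt.ElementWiseOperation.SUB)
    # exp(shifted)
    exp_layer = ctx.network.add_unary(shifted.get_output(0), trt.UnaryOperation.EXP)
    # sum(exp(shifted)) over the reduction axes
    sum_layer = ctx.network.add_reduce(exp_layer.get_output(0), trt.ReduceOperation.SUM, axes, keep_dims=True)
    # log(sum(exp(shifted)))
    log_layer = ctx.network.add_unary(sum_layer.get_output(0), trt.UnaryOperation.LOG)
    # shifted - log(sum(exp(shifted)))
    result = ctx.network.add_elementwise(shifted.get_output(0), log_layer.get_output(0), trt.ElementWiseOperation.SUB)

    output._trt = result.get_output(0)

With this version I see the same pattern as above: one test passes and the rest fail.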

gabe-scorebreak, May 30 '23 15:05