
Unexpected result on multi-batch gather.

[Open] grimoire opened this issue 3 years ago • 4 comments

Description

Gathering with topk indices on a multi-batch tensor gives unexpected results. Note that if we replace the optimization profile with:

    C=10
    input_shapes = {
        'input': {
            'min_shape': [1, C, 4],
            'opt_shape': [2, C, 4],
            'max_shape': [4, C, 4]
        }
    }

the result is correct.

Please see the code below for more details.

Environment

TensorRT Version: 8.4.1.5
NVIDIA GPU: 2060s
NVIDIA Driver Version: 510.85.02
CUDA Version: 11.3
CUDNN Version: 8.2.1
Operating System: Ubuntu 18.04
Python Version (if applicable): 3.7
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 1.10.0
Baremetal or Container (if so, version):

Relevant Files

Steps To Reproduce

import torch
import tensorrt as trt
import onnx
from typing import Dict


def from_onnx(onnx_model, input_shapes, max_workspace_size):
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    EXPLICIT_BATCH = 1 << (int)(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)

    # parse onnx
    parser = trt.OnnxParser(network, logger)

    if isinstance(onnx_model, str):
        onnx_model = onnx.load(onnx_model)

    if not parser.parse(onnx_model.SerializeToString()):
        error_msgs = ''
        for error in range(parser.num_errors):
            error_msgs += f'{parser.get_error(error)}\n'
        raise RuntimeError(f'Failed to parse onnx, {error_msgs}')

    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size

    profile = builder.create_optimization_profile()

    for input_name, param in input_shapes.items():
        min_shape = param['min_shape']
        opt_shape = param['opt_shape']
        max_shape = param['max_shape']
        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)

    return engine


TORCH_DTYPE_MAP = {
    trt.bool: torch.bool,
    trt.int8: torch.int8,
    trt.int32: torch.int32,
    trt.float16: torch.float16,
    trt.float32: torch.float32
}


class TRTWrapper(torch.nn.Module):

    def __init__(self, engine: trt.ICudaEngine):
        super().__init__()
        self.engine = engine

        if not isinstance(self.engine, trt.ICudaEngine):
            raise TypeError('`engine` should be a trt.ICudaEngine, '
                            f'but given: {type(self.engine)}')

        self.context = self.engine.create_execution_context()
        self.__load_io_names()

    def __load_io_names(self):
        """Load input/output names from engine."""
        names = [_ for _ in self.engine]
        input_names = list(filter(self.engine.binding_is_input, names))
        self._input_names = input_names

        output_names = list(set(names) - set(input_names))
        self._output_names = output_names

    def forward(self, inputs: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run forward inference.

        Args:
            inputs (Dict[str, torch.Tensor]): The input name and tensor pairs.

        Return:
            Dict[str, torch.Tensor]: The output name and tensor pairs.
        """
        bindings = [None] * (len(self._input_names) + len(self._output_names))

        for input_name, input_tensor in inputs.items():
            idx = self.engine.get_binding_index(input_name)

            # All input tensors must be gpu variables
            input_tensor = input_tensor.contiguous()
            if input_tensor.dtype == torch.long:
                input_tensor = input_tensor.int()
            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            bindings[idx] = input_tensor.contiguous().data_ptr()

        # create output tensors
        outputs = {}
        for output_name in self._output_names:
            idx = self.engine.get_binding_index(output_name)
            dtype = TORCH_DTYPE_MAP[self.engine.get_binding_dtype(idx)]
            shape = tuple(self.context.get_binding_shape(idx))

            output = torch.empty(size=shape, dtype=dtype, device='cuda')
            outputs[output_name] = output
            bindings[idx] = output.data_ptr()

        self.context.execute_async_v2(bindings,
                                      torch.cuda.current_stream().cuda_stream)

        return outputs


class TestModel(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        batch_size = x.size(0)
        C = x.size(1)
        max_x, _ = x.max(-1)
        _, inds = max_x.topk(4)
        batch_inds = torch.arange(batch_size, device=inds.device).unsqueeze(-1)

        # Batched gather via advanced indexing; the commented-out torch.gather
        # and flatten formulations are equivalent in eager PyTorch.
        # new_x = torch.gather(x, 1, inds.unsqueeze(-1).expand(batch_size, 4, 4))
        new_x = x[batch_inds, inds, ...]
        # new_x = x.flatten(0, 1)[inds + batch_inds * C]
        return new_x, inds + batch_inds * C


def main():
    # models
    model = TestModel().cuda()
    x = torch.rand(1, 10, 4).cuda()

    # export onnx
    input_names = ['input']
    output_names = ['output', 'inds']
    torch.onnx.export(
        model,
        x,
        'tmp.onnx',
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={'input': {
            0: 'b',
            1: 'n'
        }},
        opset_version=11)

    # export tensorrt
    input_shapes = {
        'input': {
            'min_shape': [1, 5, 4],
            'opt_shape': [2, 10, 4],
            'max_shape': [4, 40, 4]
        }
    }
    engine = from_onnx(
        'tmp.onnx', input_shapes=input_shapes, max_workspace_size=1 << 30)

    wrapper = TRTWrapper(engine)

    x = torch.rand(2, 10, 4).cuda()

    torch_out = model(x)
    out = wrapper({'input': x})
    out = [out[name] for name in output_names]

    # print(x)

    for o, to in zip(out, torch_out):
        print(o.shape)
        torch.testing.assert_allclose(o, to)

    # print(torch_out)


if __name__ == '__main__':
    main()

— grimoire, Sep 06 '22

I tried to reproduce the issue with TRT 8.4.1.5 using polygraphy:

[I] onnxrt-runner-N0-09/07/22-08:16:25  | Completed 1 iteration(s) in 0.1693 ms | Average inference time: 0.1693 ms.
[I] Accuracy Comparison | trt-runner-N0-09/07/22-08:16:25 vs. onnxrt-runner-N0-09/07/22-08:16:25
[I]     Comparing Output: 'output' (dtype=float32, shape=(2, 4, 4)) with 'output' (dtype=float32, shape=(2, 4, 4))
[I]     Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-09/07/22-08:16:25: output | Stats: mean=0.58187, std-dev=0.31482, var=0.099113, median=0.6921, min=0.039055 at (0, 2, 3), max=0.98886 at (1, 0, 0), avg-magnitude=0.58187
[I]         onnxrt-runner-N0-09/07/22-08:16:25: output | Stats: mean=0.58187, std-dev=0.31482, var=0.099113, median=0.6921, min=0.039055 at (0, 2, 3), max=0.98886 at (1, 0, 0), avg-magnitude=0.58187
[I]         Error Metrics: output
[I]             Minimum Required Tolerance: elemwise error | [abs=0] OR [rel=0] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0, 0), max=0 at (0, 0, 0), avg-magnitude=0
[I]             Relative Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0, 0), max=0 at (0, 0, 0), avg-magnitude=0
[I]         PASSED | Difference is within tolerance (rel=1e-05, abs=1e-05)
[I]     Comparing Output: 'inds' (dtype=int32, shape=(2, 4)) with 'inds' (dtype=int64, shape=(2, 4))
[I]     Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-09/07/22-08:16:25: inds | Stats: mean=10.375, std-dev=4.7942, var=22.984, median=9, min=5 at (0, 0), max=19 at (1, 2), avg-magnitude=10.375
[I]         onnxrt-runner-N0-09/07/22-08:16:25: inds | Stats: mean=10.375, std-dev=4.7942, var=22.984, median=9, min=5 at (0, 0), max=19 at (1, 2), avg-magnitude=10.375
[I]         Error Metrics: inds
[I]             Minimum Required Tolerance: elemwise error | [abs=0] OR [rel=0] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0), max=0 at (0, 0), avg-magnitude=0
[I]             Relative Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0), max=0 at (0, 0), avg-magnitude=0
[I]         PASSED | Difference is within tolerance (rel=1e-05, abs=1e-05)
[I]     PASSED | All outputs matched | Outputs: ['output', 'inds']
[I] PASSED | Command: /home/zeroz/.local/bin/polygraphy run tmp.onnx --trt --onnxrt --trt-opt-shapes input:[2,10,4] --input-shapes input:[2,10,4]

The accuracy matches between TRT and ONNX Runtime. Can you check whether it also matches between Torch and ONNX Runtime?
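
A minimal sketch of such a check (assuming onnxruntime is installed; `TestModel` and `tmp.onnx` are taken from the repro script above, so this is not standalone):

    import numpy as np
    import onnxruntime as ort
    import torch

    # Run the PyTorch model and the exported ONNX model on the same input
    # and compare the outputs element-wise.
    model = TestModel().cuda()
    x = torch.rand(2, 10, 4).cuda()
    torch_out = [t.cpu().numpy() for t in model(x)]

    sess = ort.InferenceSession('tmp.onnx', providers=['CPUExecutionProvider'])
    ort_out = sess.run(None, {'input': x.cpu().numpy()})

    for t, o in zip(torch_out, ort_out):
        np.testing.assert_allclose(t, o.astype(t.dtype), rtol=1e-5, atol=1e-5)
    print('Torch and ONNX Runtime outputs match')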

— zerollzeng, Sep 07 '22

The results match between Torch and ONNX. Note that this error only appears when min_shape[0] != max_shape[0] and min_shape[1] != max_shape[1].
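
For concreteness, the two profiles from this thread side by side (same shapes as above, illustrative variable names): the first, with both dim 0 and dim 1 dynamic, reproduces the mismatch, while the second, which fixes dim 1, gives correct results.

    # Reproduces the mismatch: both dim 0 and dim 1 vary across the profile.
    failing_shapes = {
        'input': {
            'min_shape': [1, 5, 4],
            'opt_shape': [2, 10, 4],
            'max_shape': [4, 40, 4]
        }
    }

    # Correct results: dim 1 is fixed (C = 10), only the batch dim is dynamic.
    working_shapes = {
        'input': {
            'min_shape': [1, 10, 4],
            'opt_shape': [2, 10, 4],
            'max_shape': [4, 10, 4]
        }
    }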

— grimoire, Sep 08 '22

I can reproduce this with the fully dynamic profile:

[I]         trt-runner-N0-09/09/22-00:16:40: output | Stats: mean=0.35972, std-dev=0.34652, var=0.12008, median=0.27958, min=0 at (1, 0, 0), max=0.96826 at (0, 0, 1), avg-magnitude=0.35972
[I]             ---- Histogram ----
                Bin Range        |  Num Elems | Visualization
                (0     , 0.0989) |         12 | ########################################
                (0.0989, 0.198 ) |          2 | ######
                (0.198 , 0.297 ) |          3 | ##########
                (0.297 , 0.396 ) |          2 | ######
                (0.396 , 0.494 ) |          3 | ##########
                (0.494 , 0.593 ) |          1 | ###
                (0.593 , 0.692 ) |          1 | ###
                (0.692 , 0.791 ) |          1 | ###
                (0.791 , 0.89  ) |          3 | ##########
                (0.89  , 0.989 ) |          4 | #############
[I]         onnxrt-runner-N0-09/09/22-00:16:40: output | Stats: mean=0.58187, std-dev=0.31482, var=0.099113, median=0.6921, min=0.039055 at (0, 2, 3), max=0.98886 at (1, 0, 0), avg-magnitude=0.58187
[I]             ---- Histogram ----
                Bin Range        |  Num Elems | Visualization
                (0     , 0.0989) |          3 | ###############
                (0.0989, 0.198 ) |          3 | ###############
                (0.198 , 0.297 ) |          2 | ##########
                (0.297 , 0.396 ) |          3 | ###############
                (0.396 , 0.494 ) |          2 | ##########
                (0.494 , 0.593 ) |          2 | ##########
                (0.593 , 0.692 ) |          1 | #####
                (0.692 , 0.791 ) |          5 | #########################
                (0.791 , 0.89  ) |          3 | ###############
                (0.89  , 0.989 ) |          8 | ########################################
[I]         Error Metrics: output
[I]             Minimum Required Tolerance: elemwise error | [abs=0.98886] OR [rel=1.1358] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.23223, std-dev=0.32724, var=0.10709, median=0.0025968, min=0 at (0, 0, 0), max=0.98886 at (1, 0, 0), avg-magnitude=0.23223
[I]                 ---- Histogram ----
                    Bin Range        |  Num Elems | Visualization
                    (0     , 0.0989) |         18 | ########################################
                    (0.0989, 0.198 ) |          3 | ######
                    (0.198 , 0.297 ) |          3 | ######
                    (0.297 , 0.396 ) |          0 |
                    (0.396 , 0.494 ) |          1 | ##
                    (0.494 , 0.593 ) |          0 |
                    (0.593 , 0.692 ) |          1 | ##
                    (0.692 , 0.791 ) |          3 | ######
                    (0.791 , 0.89  ) |          1 | ##
                    (0.89  , 0.989 ) |          2 | ####
[I]             Relative Difference | Stats: mean=0.39215, std-dev=0.46259, var=0.21399, median=0.0028745, min=0 at (0, 0, 0), max=1.1358 at (1, 1, 3), avg-magnitude=0.39215
[I]                 ---- Histogram ----
                    Bin Range      |  Num Elems | Visualization
                    (0    , 0.114) |         17 | ########################################
                    (0.114, 0.227) |          0 |
                    (0.227, 0.341) |          2 | ####
                    (0.341, 0.454) |          1 | ##
                    (0.454, 0.568) |          0 |
                    (0.568, 0.681) |          0 |
                    (0.681, 0.795) |          1 | ##
                    (0.795, 0.909) |          1 | ##
                    (0.909, 1.02 ) |          9 | #####################
                    (1.02 , 1.14 ) |          1 | ##
[E]         FAILED | Difference exceeds tolerance (rel=1e-05, abs=1e-05)
[I]     Comparing Output: 'inds' (dtype=int32, shape=(2, 4)) with 'inds' (dtype=int64, shape=(2, 4))
[I]     Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-09/09/22-00:16:40: inds | Stats: mean=10.375, std-dev=4.7942, var=22.984, median=9, min=5 at (0, 0), max=19 at (1, 2), avg-magnitude=10.375
[I]         onnxrt-runner-N0-09/09/22-00:16:40: inds | Stats: mean=10.375, std-dev=4.7942, var=22.984, median=9, min=5 at (0, 0), max=19 at (1, 2), avg-magnitude=10.375
[I]         Error Metrics: inds
[I]             Minimum Required Tolerance: elemwise error | [abs=0] OR [rel=0] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0), max=0 at (0, 0), avg-magnitude=0
[I]             Relative Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0), max=0 at (0, 0), avg-magnitude=0
[I]         PASSED | Difference is within tolerance (rel=1e-05, abs=1e-05)
[E]     FAILED | Mismatched outputs: ['output']
[!] FAILED | Command: /home/zeroz/.local/bin/polygraphy run tmp.onnx --trt --onnxrt --trt-opt-shapes input:[2,10,4] --trt-min-shapes input:[1,5,4] --trt-max-shapes input:[4,40,4] --input-shapes input:[2,10,4]

I've filed internal bug 3790543 to track this. Thanks for reporting.

— zerollzeng, Sep 09 '22

The issue has been fixed in TRT 8.5; a preview feature will address it. Please wait for the 8.5 release, coming soon :-)
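
If the preview feature in question turns out to be FASTER_DYNAMIC_SHAPES_0805 (not confirmed here), enabling it on the builder config would look roughly like:

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()

    # Assumption: the relevant preview feature is FASTER_DYNAMIC_SHAPES_0805;
    # guard on the enum existing (TRT >= 8.5 only).
    if hasattr(trt, 'PreviewFeature'):
        config.set_preview_feature(
            trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805, True)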

— zerollzeng, Sep 20 '22