TensorRT

How can I make a dynamic input slice in TensorRT?

Open OlegSkuibida opened this issue 2 years ago • 3 comments

I want to create a pipeline for an object-counting system: YOLOv7 model with the NMS plugin (one image per batch) -> dynamic slice -> CropAndResize plugin -> feature-extraction model (batched over the extracted objects). None of the ONNX approaches I tried work in TensorRT; slicing only works when the slice size is declared as a constant. Is there perhaps a plugin for this?

OlegSkuibida avatar Aug 30 '22 11:08 OlegSkuibida

I don't quite get what your problem is, can you elaborate on it?

zerollzeng avatar Aug 31 '22 10:08 zerollzeng

I don't quite get what your problem is, can you elaborate on it?

For example, with this code:

            # Split the detection-head output into box coordinates, objectness and
            # class scores, then multiply objectness into the class scores.
            # output [1, 8400, 85]
            # (presumably [bs, num_boxes, 5 + num_classes] -- TODO confirm)
            # slice boxes, obj_score, class_scores
            strides = trt.Dims([1,1,1])
            starts = trt.Dims([0,0,0])
            # NOTE(review): these are build-time dims. With a dynamic batch, bs is -1.
            # A -1 slice size is what makes the slice layer report an error.
            # In that case the size has to be supplied at runtime as a shape tensor
            # through slice.set_input(2, tensor). That tensor can be computed from
            # network.add_shape(previous_output) with elementwise/concatenation layers.
            bs, num_boxes, temp = previous_output.shape
            shapes = trt.Dims([bs, num_boxes, 4])
            # [0, 0, 0] [1, 8400, 4] [1, 1, 1]
            boxes = self.network.add_slice(previous_output, starts, shapes, strides)
            num_classes = temp -5 
            # starts/shapes are mutated in place and reused for the next two slices.
            # This assumes add_slice has already copied the Dims values -- TODO confirm.
            starts[2] = 4
            shapes[2] = 1
            # [0, 0, 4] [1, 8400, 1] [1, 1, 1]
            obj_score = self.network.add_slice(previous_output, starts, shapes, strides)
            starts[2] = 5
            shapes[2] = num_classes
            # [0, 0, 5] [1, 8400, 80] [1, 1, 1]
            scores = self.network.add_slice(previous_output, starts, shapes, strides)
            # scores = obj_score * class_scores => [bs, num_boxes, nc]
            # (obj_score's last dim of 1 broadcasts against the class dimension)
            updated_scores = self.network.add_elementwise(obj_score.get_output(0), scores.get_output(0), trt.ElementWiseOperation.PROD)

When I use a dynamic batch size, bs is -1, and the slice layer then reports an error.

Linaom1214 avatar Sep 06 '22 14:09 Linaom1214

It's hard to see the problem from this code snippet. Can you provide full reproduction code here, or (preferably) an ONNX model that reproduces the error?

zerollzeng avatar Sep 06 '22 14:09 zerollzeng

Model: yolov7-w6-person.onnx (link). Its last layer is the EfficientNMS plugin. I want to slice all of the model's outputs using the detection count from `network.get_output(0)[0][0]`.

def add_slice(network):
    """Trim each EfficientNMS detection output to the number of valid detections.

    Model: yolov7-w6-person.onnx
    https://drive.google.com/file/d/1-N_7zBvS05gNzAl8abHZHdNsjd1fAFNr/view?usp=sharing
    Input shape 1, 3, 640, 640. Outputs (num_boxes is static, 48):
        num_dets     [bs, 1]
        det_boxes    [bs, num_boxes, 4]
        det_scores   [bs, num_boxes]
        det_classes  [bs, num_boxes]

    Every output after ``num_dets`` is replaced by a slice whose second
    dimension is the runtime value of ``num_dets``. The slice is marked as a
    network output named ``<original name>_sliced``, and the original output
    is unmarked.

    Only bs == 1 is handled: with several images each one may have a different
    detection count, which a single slice cannot express.

    Args:
        network: TensorRT network definition whose output 0 is ``num_dets``.

    Returns:
        The same ``network``, modified in place.
    """
    import numpy as np  # local import: used for the constant pieces of the size tensor

    def int_const(values):
        # 1-D int32 constant tensor used as a piece of a runtime size tensor.
        arr = np.array(values, dtype=np.int32)
        return network.add_constant(shape=arr.shape, weights=arr).get_output(0)

    # ITensor is not subscriptable (get_output(0)[0][0] raises TypeError), so
    # flatten num_dets from [1, 1] to [1] with a shuffle layer instead.
    # Assumes num_dets is int32, since concatenation needs matching dtypes -- confirm.
    flatten = network.add_shuffle(network.get_output(0))
    flatten.reshape_dims = (1,)
    num_dets = flatten.get_output(0)

    # Snapshot the detection outputs before the output list is modified below.
    detections = [network.get_output(i) for i in range(1, network.num_outputs)]
    for det in detections:
        # 3 for det_boxes, 2 for det_scores / det_classes; start/stride must match.
        rank = len(det.shape)

        # Runtime size = [1, num_dets, *trailing static dims].
        pieces = [int_const([1]), num_dets]
        if rank > 2:
            pieces.append(int_const(list(det.shape)[2:]))
        size = network.add_concatenation(pieces)
        size.axis = 0

        # The static size is only a placeholder; input 2 (size) overrides it.
        layer = network.add_slice(
            det, start=(0,) * rank, shape=(1,) * rank, stride=(1,) * rank
        )
        layer.set_input(2, size.get_output(0))
        # NOTE(review): num_dets is produced by the NMS plugin, so this makes the
        # slice size data-dependent. Some TensorRT versions may reject that at
        # build time; the alternative is to expose num_dets as a shape-tensor
        # network input. Confirm for your TensorRT version.

        sliced = layer.get_output(0)
        sliced.name = f"{det.name}_sliced"
        network.unmark_output(det)
        network.mark_output(sliced)
    return network

OlegSkuibida avatar Sep 26 '22 09:09 OlegSkuibida

@pranavm-nvidia Can you help ^ ^

zerollzeng avatar Sep 26 '22 14:09 zerollzeng

You can use slice.set_input(index, tensor) for dynamic slice. See the documentation for details.

To create the shape, you should be able to use the IShapeLayer in combination with layers like gather, elementwise, and concat.

pranavm-nvidia avatar Sep 26 '22 17:09 pranavm-nvidia

You can use slice.set_input(index, tensor) for dynamic slice. See the documentation for details.

Thank you. I saw it, but I could not find any examples of how to use it. Could you please provide some?

OlegSkuibida avatar Sep 26 '22 17:09 OlegSkuibida

I wrote up a short example, let me know if it helps:

#!/usr/bin/env python3
# Generation Command: polygraphy template trt-network -o use_dynamic_slice.py
import numpy as np
import tensorrt as trt
from polygraphy import func
from polygraphy.backend.trt import CreateNetwork

"""
Builds a network with a dynamic slice layer.
Also passes custom input data so it's easy to see the behavior of the network.

Run with:
polygraphy run use_dynamic_slice.py --trt --input-shapes input:[1,3,4,4] --data-loader-script use_dynamic_slice.py -vv
"""


@func.extend(CreateNetwork())
def load_network(builder, network):
    """Add a slice that keeps the top-left (H / 2, W / 2) quadrant of "input".

    The slice is created with placeholder start/shape/stride values. Its size
    input (index 2) is then replaced by a shape tensor computed from the
    input's actual dimensions when the engine runs.
    """
    image = network.add_input("input", shape=(1, 3, -1, -1), dtype=trt.float32)

    # Placeholder parameters; the size is swapped for a dynamic tensor below.
    quadrant = network.add_slice(
        image,
        start=(0, 0, 0, 0),
        shape=(1, 3, 1, 1),
        stride=(1, 1, 1, 1),
    )

    # Runtime size = shape(image) / [1, 1, 2, 2]  ->  (N, C, H / 2, W / 2).
    image_dims = network.add_shape(image).get_output(0)
    halving = np.array([1, 1, 2, 2], dtype=np.int32)
    divisor = network.add_constant(shape=(4,), weights=halving).get_output(0)
    half_dims = network.add_elementwise(
        image_dims, divisor, trt.ElementWiseOperation.DIV
    ).get_output(0)

    # ISliceLayer input index 2 is the output size. See the API reference:
    # https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html?highlight=islicelayer#tensorrt.ISliceLayer.set_input.
    quadrant.set_input(2, half_dims)

    network.mark_output(quadrant.get_output(0))


def load_data():  # one feed dict: "input" = float32 ramp 0..47 shaped (1, 3, 4, 4)
    return [dict(input=np.arange(48, dtype=np.float32).reshape((1, 3, 4, 4)))]

You can save this as use_dynamic_slice.py and then follow the instructions in the docstring if you want to run it locally.

pranavm-nvidia avatar Sep 26 '22 18:09 pranavm-nvidia

Thank you, I think it will help. I will try it with my code tomorrow. I'm having some trouble running yours:

root@b1dc414c852e:/models# polygraphy run use_dynamic_slice.py --trt --input-shapes input:[1,3,4,4] --data-loader-script use_dynamic_slice.py -vv
[V] Loaded Module: polygraphy.util    | Path: ['/usr/local/lib/python3.8/dist-packages/polygraphy/util']
[V] Model: use_dynamic_slice.py
[V] Loaded Module: polygraphy         | Version: 0.38.0 | Path: ['/usr/local/lib/python3.8/dist-packages/polygraphy']
[V] Loaded extension modules: []
[V] Loaded Module: tensorrt           | Version: 8.4.2.4 | Path: ['/usr/local/lib/python3.8/dist-packages/tensorrt']
[!] Could not import symbol: load_data from script: use_dynamic_slice.py
    Note: Error was: module 'use_dynamic_slice' has no attribute 'load_data'
    Note: sys.path was: ['', '/usr/local/bin', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '/usr/local/lib/python3.8/dist-packages', '/usr/lib/python3/dist-packages']
root@b1dc414c852e:/models#

OlegSkuibida avatar Sep 26 '22 19:09 OlegSkuibida

Weird — maybe you missed the last two lines of the script?

def load_data():  # one feed dict: "input" = float32 ramp 0..47 shaped (1, 3, 4, 4)
    return [dict(input=np.arange(48, dtype=np.float32).reshape((1, 3, 4, 4)))]

Otherwise you can omit the custom data:

polygraphy run use_dynamic_slice.py --trt --input-shapes input:[1,3,4,4]  -vv

pranavm-nvidia avatar Sep 26 '22 19:09 pranavm-nvidia

Thanks ;) I got it working:

#!/usr/bin/env python3
import numpy as np
import tensorrt as trt
from polygraphy.backend.trt import *

def add_slice(builder, network):
    """Build, save and return an engine whose slice size is a runtime input.

    "num_dets" is a 3-element int32 input holding the size of the slice taken
    from "det_boxes" (1, 32, 4). The sliced tensor is exposed as "output", and
    the engine is serialized to 'test.engine'.
    """
    size_input = network.add_input("num_dets", shape=(3,), dtype=trt.int32)
    boxes_input = network.add_input("det_boxes", shape=(1, 32, 4), dtype=trt.float32)

    # The static size (1, 1, 4) is a placeholder; input 2 (size) comes from "num_dets".
    box_slice = network.add_slice(
        boxes_input, start=(0, 0, 0), shape=(1, 1, 4), stride=(1, 1, 1)
    )
    box_slice.set_input(2, size_input)

    sliced = box_slice.get_output(0)
    network.mark_output(sliced)
    sliced.name = 'output'

    # Bounds for the values that may be fed to the "num_dets" input.
    profile = Profile().add('num_dets', min=(1, 1, 4), opt=(1, 16, 4), max=(1, 32, 4))
    build = EngineFromNetwork((builder, network), CreateConfig(profiles=[profile]))
    engine = build()
    SaveEngine(engine, 'test.engine')()
    return engine

if __name__ == "__main__":
    # Build the engine, then run it with two slice sizes fed through "num_dets"
    # to show the output shape following the requested size.
    builder, network = CreateNetwork()()
    engine = add_slice(builder, network)
    print("Engine created")
    num_dets_0, num_dets_1 = (
        np.array(dims, dtype=np.int32) for dims in ([1, 1, 4], [1, 3, 4])
    )
    det_boxes = np.random.rand(1, 32, 4).astype(np.float32)
    print('det_boxes input shape:', det_boxes.shape)
    with TrtRunner(engine) as runner:
        out0 = runner.infer({"num_dets": num_dets_0, 'det_boxes': det_boxes})
        print(f'0 input slice shape {num_dets_0} out 0 shape:{out0["output"].shape}')
        out1 = runner.infer({"num_dets": num_dets_1, 'det_boxes': det_boxes})
        print(f'1 input slice shape {num_dets_1} out 1 shape:{out1["output"].shape}')

OlegSkuibida avatar Sep 27 '22 23:09 OlegSkuibida

Updated version, with 'num_dets' as a single value (shape (1,)):

#!/usr/bin/env python3
import numpy as np
import tensorrt as trt
from polygraphy.backend.trt import *

num_boxes = 32  # static box capacity of "det_boxes"; also the profile's max for "num_dets"

def add_slice(builder, network):
    """Build, save and return an engine that keeps the first ``num_dets`` boxes.

    "num_dets" is a one-element int32 input. It is expanded into the slice
    size [1, num_dets, 4] and applied to "det_boxes" (1, num_boxes, 4). The
    result is exposed as "output", and the engine is serialized to
    'test2.engine'.
    """
    count = network.add_input("num_dets", shape=(1,), dtype=trt.int32)
    boxes = network.add_input("det_boxes", shape=(1, num_boxes, 4), dtype=trt.float32)

    # Placeholder static slice; the real size is wired in below.
    box_slice = network.add_slice(boxes, start=(0, 0, 0), shape=(1, 1, 4), stride=(1, 1, 1))

    # size = [0, 1, 0] * num_dets + [1, 0, 4]  ->  [1, num_dets, 4]
    select = network.add_constant(
        shape=(3,), weights=np.array([0, 1, 0], dtype=np.int32)
    ).get_output(0)
    offset = network.add_constant(
        shape=(3,), weights=np.array([1, 0, 4], dtype=np.int32)
    ).get_output(0)
    scaled = network.add_elementwise(select, count, trt.ElementWiseOperation.PROD).get_output(0)
    size = network.add_elementwise(scaled, offset, trt.ElementWiseOperation.SUM).get_output(0)
    box_slice.set_input(2, size)

    sliced = box_slice.get_output(0)
    network.mark_output(sliced)
    sliced.name = 'output'

    # Allowed values for "num_dets": 1 .. num_boxes.
    profile = Profile().add('num_dets', min=[1], opt=[num_boxes // 2], max=[num_boxes])
    build = EngineFromNetwork((builder, network), CreateConfig(profiles=[profile]))
    engine = build()
    SaveEngine(engine, 'test2.engine')()
    return engine

if __name__ == "__main__":
    # Build the engine, then run it with detection counts 1 and 8 to show the
    # output's second dimension following the "num_dets" value.
    builder, network = CreateNetwork()()
    engine = add_slice(builder, network)
    print("Engine created")
    num_dets_0, num_dets_1 = (np.array([count], dtype=np.int32) for count in (1, 8))
    det_boxes = np.random.rand(1, num_boxes, 4).astype(np.float32)
    print('det_boxes input shape:', det_boxes.shape)
    with TrtRunner(engine) as runner:
        # Print each result before the next inference call.
        out0 = runner.infer({"num_dets": num_dets_0, 'det_boxes': det_boxes})
        print(f'0 input slice shape {num_dets_0} out 0 shape:{out0["output"].shape}')
        print(out0["output"])
        out1 = runner.infer({"num_dets": num_dets_1, 'det_boxes': det_boxes})
        print(f'1 input slice shape {num_dets_1} out 1 shape:{out1["output"].shape}')
        print(out1["output"])

OlegSkuibida avatar Sep 28 '22 12:09 OlegSkuibida