TensorRT-SSD

Inference

Open Maxfashko opened this issue 7 years ago • 13 comments

Hey. I implemented inference using this code, but the predicted boxes do not correspond to the actual locations of objects in the image. I did not implement the softmax layer in the code; instead I used the standard Caffe Softmax layer:

layer {
  name: "mbox_conf_softmax1"
  type: "Softmax"
  bottom: "mbox_conf_reshape"
  top: "mbox_conf_softmax1"
  softmax_param {
    axis: 2
  }
}

Do I have to use a softmax layer that I implement myself?

example

Maxfashko avatar Feb 22 '18 09:02 Maxfashko

The TensorRT softmax layer API does not support axis: 2. You must implement it yourself.
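As a rough CPU reference for what such a hand-written layer has to compute, here is a minimal sketch. It views the flat buffer as [outer, channels, inner], which is how the Caffe softmax layer treats its axis parameter. The function name `softmaxAlongAxis` is hypothetical, and the shapes for mbox_conf_reshape (outer = number of priors, channels = number of classes, inner = 1) are an assumption, not something taken from this repository.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference: softmax over the `channels` axis of a flat, row-major buffer
// viewed as [outer, channels, inner].
// Assumed shapes for axis: 2 on mbox_conf_reshape:
// outer = number of priors, channels = number of classes, inner = 1.
std::vector<float> softmaxAlongAxis(const std::vector<float>& in,
                                    std::size_t outer,
                                    std::size_t channels,
                                    std::size_t inner)
{
    assert(in.size() == outer * channels * inner);
    std::vector<float> out(in.size());
    for (std::size_t o = 0; o < outer; ++o) {
        for (std::size_t i = 0; i < inner; ++i) {
            const std::size_t base = o * channels * inner + i;
            // 1. Max over the axis, subtracted later for numerical stability.
            float maxVal = in[base];
            for (std::size_t c = 1; c < channels; ++c)
                maxVal = std::max(maxVal, in[base + c * inner]);
            // 2. Exponentiate the shifted values and accumulate their sum.
            float sum = 0.0f;
            for (std::size_t c = 0; c < channels; ++c) {
                out[base + c * inner] = std::exp(in[base + c * inner] - maxVal);
                sum += out[base + c * inner];
            }
            // 3. Normalize so the values along the axis sum to 1.
            for (std::size_t c = 0; c < channels; ++c)
                out[base + c * inner] /= sum;
        }
    }
    return out;
}
```

A GPU kernel can be checked against this by running both on the same input and comparing the outputs element by element.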

chenzhi1992 avatar Feb 23 '18 01:02 chenzhi1992

Hi @chenzhi1992. I cannot work out the format in which the data arrives in the enqueue function. As far as I can tell, it arrives as a flat vector: for the 21 Pascal VOC classes there are 12764 groups of values stored sequentially, so the input parameter is a pointer to a vector of 12764 x 21 = 268044 elements.

//pluginimlement.cpp
int SoftmaxPlugin::enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream)
{
    float *top_data = reinterpret_cast<float*>(outputs[0]);
    const float *bottom_data = reinterpret_cast<const float*>(inputs[0]);
    SoftmaxLayer(size_array, bottom_data, top_data, stream);
    return 0;
}

//I defined the output of the layer as follows:

Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
    {
        assert(nbInputDims == 1);        
        assert(index == 0);
        assert(inputs[index].nbDims == 3);  
        return DimsCHW(inputs->d[0], inputs->d[1], inputs->d[2]);  // keep the output dimensions
    }

void configure(const Dims*inputs, int nbInputs, const Dims* outputs, int nbOutputs, int)	override
    {
        mCopySize = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2] * sizeof(float);
        size_array = inputs[0].d[0]*inputs[0].d[1];
    }
//mathFunctions.h
cudaError_t SoftmaxLayer(int nthreads,
                         const float *bottom_data,
                         float *top_data,
                         cudaStream_t stream);

The kernel source code is simple, and I will not post it.

//mathFunction.cu

cudaError_t SoftmaxLayer(int nthreads,
                         const float *bottom_data,
                         float *top_data,
                         cudaStream_t stream)
{
    float *max_elem = new float;
    float *sum      = new float;

    // search for the max element
    kernel_channel_max<float><<<TENSORRT_GET_BLOCKS(nthreads), TENSORRT_CUDA_NUM_THREADS,0,stream>>>(nthreads, bottom_data, *max_elem);

    // subtract max elem
    kernel_channel_subtract<float><<<TENSORRT_GET_BLOCKS(nthreads), TENSORRT_CUDA_NUM_THREADS,0,stream>>>(nthreads, bottom_data, *max_elem, top_data);

    // exp. The data from which the maximum element was subtracted are in top_data
    kernel_exp<float><<<TENSORRT_GET_BLOCKS(nthreads), TENSORRT_CUDA_NUM_THREADS,0,stream>>>(nthreads, top_data, top_data);

    // sum
    kernel_sum<float><<<TENSORRT_GET_BLOCKS(nthreads), TENSORRT_CUDA_NUM_THREADS,0,stream>>>(nthreads, top_data, *sum);

    delete max_elem;
    delete sum;

    return cudaPeekAtLastError();
}

After these transformations, the output does not look like what I expect.

//main.cpp
// struct that stores the attributes of the output layer
    struct detection_out_struct {
        float image_id, label, score, xmin, ymin, xmax, ymax;
    };

// fill the vector with the attributes of the detected image
    void set_detection_out(){
        nvinfer1::DimsCHW dims = this->tensorNet.getTensorDims("detection_out");
        detection_out_struct * out = (detection_out_struct*) this->allocateMemory(dims);
        for(int i=0; i < this->tensorNet.getTensorDims("detection_out") ; i++){
            if (out[i].label != -1){
                this->detection_out.push_back(&out[i]);
            }
        }
    }

std::vector<detection_out_struct*> detection_out;   // vector of pointers to the attribute structs of the detected image
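For comparison, here is a minimal sketch of reading a detection_out buffer with an explicit element count instead of a dimensions object as the loop bound. The helper name `parseDetections` and the `minScore` parameter are hypothetical. The sketch assumes 7 floats per detection in the same order as the struct above, and that unused rows carry label == -1, which is the same check the code above relies on.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same field order as detection_out_struct above.
struct Detection {
    float image_id, label, score, xmin, ymin, xmax, ymax;
};

// Hypothetical helper: copy the valid rows out of a flat detection_out buffer.
// Assumes 7 floats per detection and label == -1 for unused rows.
std::vector<Detection> parseDetections(const float* data,
                                       std::size_t numFloats,
                                       float minScore)
{
    assert(numFloats % 7 == 0);
    std::vector<Detection> result;
    for (std::size_t i = 0; i + 7 <= numFloats; i += 7) {
        const Detection d{data[i],     data[i + 1], data[i + 2], data[i + 3],
                          data[i + 4], data[i + 5], data[i + 6]};
        // Skip padding rows and low-confidence detections.
        if (d.label == -1.0f || d.score < minScore)
            continue;
        result.push_back(d);  // stored by value, so the buffer may be reused afterwards
    }
    return result;
}
```

Storing the detections by value rather than as pointers into the buffer avoids dangling pointers if the output memory is overwritten by the next inference.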

Then I draw all the boxes and see the result: img_screenshot_01 03 2018

Maxfashko avatar Mar 01 '18 04:03 Maxfashko

What is the ConcatPlugin class implemented for? It is not used anywhere; createConcatPlugin, declared in NvInferPlugin.h, is used instead.

else if (!strcmp(layerName, "mbox_loc"))
    {
        assert(mBox_loc_layer.get() == nullptr);
        mBox_loc_layer = std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)>
                (createConcatPlugin(1, true), nvPluginDeleter);
        return mBox_loc_layer.get();
    }

Maxfashko avatar Mar 01 '18 05:03 Maxfashko

  1. I think there is a problem in your softmax layer implementation; you can refer to the caffe-ssd code.
  2. In TensorRT 3.0, the Concatenation layer links together multiple tensors of the same height and width across the channel dimension. So I need to reimplement it when axis = 2 or axis = 3.
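To illustrate point 2, concatenation along an arbitrary axis can be sketched on the CPU as follows. Each input is viewed as [outer, axisLen, inner], where outer and inner are the products of the dimensions before and after the concat axis. The names `ConcatInput` and `concatAlongAxis` are hypothetical, and this is only a reference for the indexing, not the plugin code from this repository.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// One input tensor viewed as [outer, axisLen, inner];
// `outer` and `inner` must be the same for every input.
struct ConcatInput {
    const float* data;
    std::size_t axisLen;
};

std::vector<float> concatAlongAxis(const std::vector<ConcatInput>& inputs,
                                   std::size_t outer,
                                   std::size_t inner)
{
    assert(outer > 0 && inner > 0);
    std::size_t totalAxis = 0;
    for (const ConcatInput& in : inputs)
        totalAxis += in.axisLen;

    std::vector<float> out(outer * totalAxis * inner);
    std::size_t axisOffset = 0;
    for (const ConcatInput& in : inputs) {
        // Contiguous floats that one input contributes per outer index.
        const std::size_t chunk = in.axisLen * inner;
        for (std::size_t o = 0; o < outer; ++o) {
            std::copy(in.data + o * chunk,
                      in.data + (o + 1) * chunk,
                      out.begin() + (o * totalAxis + axisOffset) * inner);
        }
        axisOffset += in.axisLen;
    }
    return out;
}
```

With outer = 1 this reduces to appending the inputs one after another, which is the channel-axis case for a single image.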

chenzhi1992 avatar Mar 01 '18 06:03 chenzhi1992

@chenzhi1992, I've reviewed the Caffe softmax layer code at https://github.com/BVLC/caffe/blob/master/src/caffe/layers/softmax_layer.cu. The main problem for me is how to set the axis for the max, sum, and exponent operations. Do I understand correctly that I get a pointer to a vector of 1x268044? How is the axis defined in this case?
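One way to think about the axis in a flat buffer is as an indexing convention. The sketch below assumes the 268044 values are laid out row-major as [numPriors][numClasses] with numClasses = 21; that layout is an assumption, and the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: position of the score for (prior, cls) in the flat buffer,
// assuming a row-major [numPriors][numClasses] layout.
inline std::size_t confIndex(std::size_t prior, std::size_t cls, std::size_t numClasses)
{
    assert(cls < numClasses);
    return prior * numClasses + cls;
}

// Inverse mapping: which prior and which class a flat index belongs to.
struct PriorClass {
    std::size_t prior;
    std::size_t cls;
};

inline PriorClass fromConfIndex(std::size_t flat, std::size_t numClasses)
{
    assert(numClasses > 0);
    return PriorClass{flat / numClasses, flat % numClasses};
}
```

Under that assumption, a softmax over the class axis normalizes each run of 21 consecutive values independently.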

Maxfashko avatar Mar 01 '18 07:03 Maxfashko

Hi @Maxfashko, I am also stuck with the same problem. Were you able to fix it, and if so could you tell how?

GraphicsHunter avatar Mar 09 '18 23:03 GraphicsHunter

@tianfanzhu, no, I did not solve this problem. I have other tasks at the moment, but as soon as I have time I will study the Caffe code.

Maxfashko avatar Mar 10 '18 12:03 Maxfashko

@Maxfashko Can you update your kernel functions for SoftmaxLayer here?

linux-devil avatar May 10 '18 10:05 linux-devil

@linux-devil, I hope I will help more than the author of this repository

__global__ void kernelSoftmax(float* x, int channels, float* y)
{
    extern __shared__ float mem[];
    __shared__ float sum_value;

    float number = *(x + blockDim.x*blockIdx.x + threadIdx.x);
    float number_exp = __expf(number);

    atomicAdd(&sum_value, number_exp);
    __syncthreads();

    y[blockDim.x*blockIdx.x + threadIdx.x] = __fdiv_rd(number_exp, sum_value);
}

void cudaSoftmax(int n, int channels, float* x, float* y)
{
    kernelSoftmax<<< (n/channels), channels, channels*sizeof(float)>>>(x, channels, y);
    cudaDeviceSynchronize();
}

Maxfashko avatar May 10 '18 11:05 Maxfashko

@Maxfashko How did you get the label names from label.txt? Thanks

quocbh avatar May 19 '18 17:05 quocbh

@quocbh https://github.com/Maxfashko/NV_TRT_SSD enjoy yourself

Maxfashko avatar May 20 '18 02:05 Maxfashko

In void cudaSoftmax(int n, int channels, float* x, float* y), what is n? How do you calculate n?
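Judging only from the launch configuration in the posted code (n/channels blocks of channels threads each), n appears to be the total number of floats in the input, so that each block handles one group of channels values. The sketch below spells out that arithmetic; the names `SoftmaxLaunch` and `softmaxLaunchFor` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Launch shape implied by kernelSoftmax<<<n / channels, channels, ...>>>.
struct SoftmaxLaunch {
    std::size_t blocks;           // one block per group of `channels` values
    std::size_t threadsPerBlock;  // one thread per value in the group
};

inline SoftmaxLaunch softmaxLaunchFor(std::size_t n, std::size_t channels)
{
    assert(channels > 0 && channels <= 1024);  // CUDA allows at most 1024 threads per block
    assert(n % channels == 0);                 // otherwise the trailing values are never processed
    return SoftmaxLaunch{n / channels, channels};
}
```

For the numbers mentioned earlier in this thread, n = 268044 and channels = 21 would give 12764 blocks of 21 threads.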

PiyalGeorge avatar Nov 14 '18 09:11 PiyalGeorge

In the kernelSoftmax code posted above by @Maxfashko, sum_value is never initialized. It should be zeroed right after "__shared__ float sum_value;", followed by a barrier so that no thread adds to it before it has been reset:

__shared__ float sum_value;
sum_value = 0;
__syncthreads();

liuchang8am avatar May 28 '19 08:05 liuchang8am