Image-Adaptive-3DLUT

CPU implementation of trilinear only supports batch_size == 1

Open — kamo262 opened this issue 2 years ago · 1 comment

I noticed that the CPU implementations of the trilinear forward and backward functions only support batch_size == 1. When the functions are called with batch_size > 1, only the first sample is computed.

I modified the functions as follows so that every sample in the batch is processed (the forward function is shown below):

void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels, const int batch)
{
    const int output_size = height * width;

    for (int batch_index = 0; batch_index < batch; ++batch_index) {
        const int batch_start_index = batch_index * output_size * channels;
        for (int index = 0; index < output_size; ++index)
        {
            float r = image[batch_start_index + index];
            float g = image[batch_start_index + index + width * height];
            float b = image[batch_start_index + index + width * height * 2];

            int r_id = floor(r / binsize);
            int g_id = floor(g / binsize);
            int b_id = floor(b / binsize);

            float r_d = fmod(r,binsize) / binsize;
            float g_d = fmod(g,binsize) / binsize;
            float b_d = fmod(b,binsize) / binsize;

            int id000 = r_id + g_id * dim + b_id * dim * dim;
            int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
            int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
            int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
            int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
            int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
            int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
            int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;

            float w000 = (1-r_d)*(1-g_d)*(1-b_d);
            float w100 = r_d*(1-g_d)*(1-b_d);
            float w010 = (1-r_d)*g_d*(1-b_d);
            float w110 = r_d*g_d*(1-b_d);
            float w001 = (1-r_d)*(1-g_d)*b_d;
            float w101 = r_d*(1-g_d)*b_d;
            float w011 = (1-r_d)*g_d*b_d;
            float w111 = r_d*g_d*b_d;

            output[batch_start_index + index] =
                w000 * lut[id000] + w100 * lut[id100] + 
                w010 * lut[id010] + w110 * lut[id110] + 
                w001 * lut[id001] + w101 * lut[id101] + 
                w011 * lut[id011] + w111 * lut[id111];

            output[batch_start_index + index + width * height] =
                w000 * lut[id000 + shift] + w100 * lut[id100 + shift] + 
                w010 * lut[id010 + shift] + w110 * lut[id110 + shift] + 
                w001 * lut[id001 + shift] + w101 * lut[id101 + shift] + 
                w011 * lut[id011 + shift] + w111 * lut[id111 + shift];

            output[batch_start_index + index + width * height * 2] =
                w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] + 
                w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] + 
                w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] + 
                w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2];
        }
    }
}

— kamo262, Feb 24 '22

Hi, thanks for sharing this code.

— HuiZeng, Feb 28 '22