
Question about technical implementation details of Z^max (Equation 3)

Open · Justin62628 opened this issue on Aug 27, 2023 · 0 comments

Hi Simon,

I'm trying to reproduce your recent paper on splatting-based synthesis for video frame interpolation; it is really nice work that has inspired me a lot. But I'm stuck implementing the numerically stable softsplat you mention in Section 3, where you say: "warp Z0 to time t as Zmax ... this step is not and need not be differentiable ...". I'd appreciate it if you could clarify the following two questions:

  1. How to implement the "backward" function of the torch.autograd.Function that computes Zmax during training. I've implemented the following snippet to compute Zmax, and it works well for the forward pass:

import collections

import torch

# note: cuda_launch, cuda_kernel, cuda_int32 and the SIZE_/VALUE_/OFFSET_ macros
# are the helper utilities defined in softsplat.py
class softsplat_zmax_func(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        # zero-initialized, so the output is effectively max(0, Zmax); still a
        # valid upper bound for stabilizing exp(Z - Zmax)
        tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])

        if tenIn.is_cuda == True:
            cuda_launch(cuda_kernel('zmax_out', '''
            
                __device__ __forceinline__ float atomicMinFloat(float* addr, float value) {
                    float old;
                    old = !signbit(value) ? __int_as_float(atomicMin((int*)addr, __float_as_int(value))) :
                        __uint_as_float(atomicMax((unsigned int*)addr, __float_as_uint(value)));
                
                    return old;
                }
                
                __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
                    float old;
                    old = !signbit(value) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
                        __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
                
                    return old;
                }
            
                extern "C" __global__ void __launch_bounds__(512) zmax_out(
                    const int n,
                    const {{type}}* __restrict__ tenIn,  // Z input only, B 1 H W
                    const {{type}}* __restrict__ tenFlow,
                    {{type}}* __restrict__ tenOut  // Z max output
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                    const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut)                  ) % SIZE_1(tenOut);
                    const int intY = ( intIndex / SIZE_3(tenOut)                                   ) % SIZE_2(tenOut);
                    const int intX = ( intIndex                                                    ) % SIZE_3(tenOut);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    // skip source pixels whose flow is not finite
                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;
                    
                    /* alternative considered: splatting into a 4x4 footprint
                       instead of just the four bilinear neighbors:
                    for (int i = intNorthwestX - 1; i < intNorthwestX + 3; i++)
                    {
                        for (int j = intNorthwestY - 1; j < intNorthwestY + 3; j++)
                        {
                            if ((i >= 0) && (i < SIZE_3(tenOut)) && (j >= 0) && (j < SIZE_2(tenOut))) {
                                atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, j, i)], fltIn);
                            }
                        }
                    } 
                    */

                    
                    // take the atomic max over the four bilinear neighbors hit by (fltX, fltY)
                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn);
                    }

                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn);
                    }

                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn);
                    }

                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn);
                    }
                    
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOut': tenOut
            }))(
                grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )

        elif tenIn.is_cuda != True:
            assert (False)

        # end

        self.save_for_backward(tenIn, tenFlow)

        return tenOut

    # end

along with some modifications to the softsplat function:

...
    elif strMode.split('-')[0] == 'soft':
        tenMetricMax = softsplat_zmax_func.apply(tenMetric, tenFlow)
        tenMetric = torch.exp(tenMetric - tenMetricMax)
        # tenMetric = torch.exp(tenMetric)  # original, numerically unstable version
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
...
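To check my understanding of why subtracting Zmax is safe (this is my reading of Equation 3 and the Zmax trick, not something taken from the repo): since Zmax[p] is fixed per target pixel p, the factor exp(-Zmax[p]) cancels between the numerator and denominator of the softmax splatting,

$$I_t[\mathbf{p}] = \frac{\sum_{\mathbf{q}} b(\mathbf{q}, \mathbf{p}) \cdot e^{Z_0[\mathbf{q}]} \cdot I_0[\mathbf{q}]}{\sum_{\mathbf{q}} b(\mathbf{q}, \mathbf{p}) \cdot e^{Z_0[\mathbf{q}]}} = \frac{\sum_{\mathbf{q}} b(\mathbf{q}, \mathbf{p}) \cdot e^{Z_0[\mathbf{q}] - Z^{max}[\mathbf{p}]} \cdot I_0[\mathbf{q}]}{\sum_{\mathbf{q}} b(\mathbf{q}, \mathbf{p}) \cdot e^{Z_0[\mathbf{q}] - Z^{max}[\mathbf{p}]}}$$

where b(q, p) is the bilinear splatting weight from source pixel q to target pixel p. The result is unchanged while every exponent stays <= 0, and if this reading is right it also suggests Zmax may be treated as a constant in the backward pass.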

It's fine for inference, but I can't figure out how to design the backward function for softsplat_zmax_func, since autograd requires some gradient definition during training so as not to mess up the optimization. One design I considered is to return no gradient at all, as sketched below, but I'm not sure that is what you intended.
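A minimal sketch, under my assumption that "need not be differentiable" means Zmax may be treated as a stop-gradient constant:

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        # the max is a piecewise-constant selection, so its gradient is zero
        # almost everywhere; propagate no gradient to either forward input
        return None, None

If that assumption holds, an equivalent alternative would be to compute tenMetricMax under torch.no_grad() (or call .detach() on it) and skip the custom backward entirely.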

  2. I noticed that the atomicMax of cupy/CUDA does not support float operands, yet you said that "This can be efficiently computed in parallel using an atomic max". Could you please share how you handled this? My current workaround is the signbit-based bit reinterpretation shown in the kernel above; a reduced version follows below.
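For reference, a minimal sketch of the reinterpretation trick (my assumption of what the "atomic max" refers to, not confirmed by the paper): for non-negative IEEE-754 floats the bit pattern is monotonically increasing when read as a signed integer, so an integer atomicMax preserves the float ordering.

    // valid only when *addr and value are both non-negative; the signbit-based
    // variant in the kernel above extends this to negative values
    __device__ __forceinline__ float atomicMaxFloatNonNegative(float* addr, float value) {
        return __int_as_float(atomicMax((int*)addr, __float_as_int(value)));
    }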

Thanks in advance!

Justin62628 · Aug 27 '23 09:08