softmax-splatting
softmax-splatting copied to clipboard
Question about technical implementation details of Z^max (Equation 3)
Hi Simon,
I'm trying to reproduce your recent paper on splatting-based synthesis for video frame interpolation — it is really nice work that inspired me a lot. However, I'm stuck implementing the numerically stable softsplat you mentioned in Section 3, where you said that "warp Z0 to time t as Zmax ... this step is and need not be differentiable ...". I'd appreciate it if you could further clarify the following two questions:
- how to implement the necessary "backward" function of torch.autograd.Function to calculate Zmax during training. I've implemented the following snippet to calculate Zmax and it works well:
class softsplat_zmax_func(torch.autograd.Function):
    """Forward-warp the importance metric Z to time t, keeping the per-pixel
    maximum of all values splatted onto each target pixel (Z^max, Eq. 3).

    Zmax is used only to stabilize the subsequent softmax via
    exp(Z - Zmax); it is treated as a constant with respect to the inputs,
    so the backward pass deliberately propagates no gradient.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, tenIn, tenFlow):
        """Scatter-max of tenIn along tenFlow.

        Args:
            ctx: autograd context (conventionally named ctx, not self).
            tenIn: metric tensor to splat; assumed B x 1 x H x W — the
                kernel comment says "Z input only" (TODO confirm channel
                count against the caller).
            tenFlow: forward optical flow, B x 2 x H x W (asserted in-kernel).

        Returns:
            Tensor shaped like tenIn with, at each target pixel, the maximum
            input value splatted onto any of its four bilinear neighbours.

        Raises:
            NotImplementedError: if the inputs are not CUDA tensors.
        """
        # Zero-initialised accumulator: pixels that receive no splat report
        # Zmax == 0. For exp(Z - Zmax) stability this is benign, since those
        # pixels also receive no softmax weight.
        tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])

        if tenIn.is_cuda:
            cuda_launch(cuda_kernel('zmax_out', '''
                /* Atomic float max/min via the standard sign-aware
                   reinterpret-cast trick: positive floats order like ints,
                   negative floats order in reverse as unsigned ints. */
                __device__ __forceinline__ float atomicMinFloat(float* addr, float value) {
                    float old;
                    old = !signbit(value) ? __int_as_float(atomicMin((int*)addr, __float_as_int(value))) :
                        __uint_as_float(atomicMax((unsigned int*)addr, __float_as_uint(value)));
                    return old;
                }

                __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
                    float old;
                    old = !signbit(value) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
                        __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
                    return old;
                }

                extern "C" __global__ void __launch_bounds__(512) zmax_out(
                    const int n,
                    const {{type}}* __restrict__ tenIn, // Z input only, B 1 H W
                    const {{type}}* __restrict__ tenFlow,
                    {{type}}* __restrict__ tenOut // Z max output
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                    const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
                    const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
                    const int intX = ( intIndex ) % SIZE_3(tenOut);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn);
                    }

                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn);
                    }

                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn);
                    }

                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                        atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn);
                    }
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOut': tenOut
            }))(
                grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )

        else:
            # A bare assert(False) would be stripped under python -O; raise
            # an explicit error for the unimplemented CPU path instead.
            raise NotImplementedError('softsplat_zmax_func is only implemented for CUDA tensors')
        # end

        # Nothing is needed for the (zero) backward; kept for parity with
        # the other softsplat autograd functions.
        ctx.save_for_backward(tenIn, tenFlow)

        return tenOut
    # end

    @staticmethod
    def backward(ctx, tenOutgrad):
        """Propagate no gradient: Zmax only stabilizes the softmax and is
        piecewise constant in the inputs ("this step ... need not be
        differentiable"), so both inputs receive None."""
        return None, None
    # end
# end
along with some modifications to the softsplat function:
...
elif strMode.split('-')[0] == 'soft':
tenMetricMax = softsplat_zmax_func.apply(tenMetric, tenFlow)
tenMetric = torch.exp(tenMetric - tenMetricMax)
# tenMetric = torch.exp(tenMetric)
tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
...
It works fine for inference, but I can't figure out how to design the backward function for softsplat_zmax_func, since it needs to provide some gradient so as not to mess up the training.
- I notice that CuPy's atomic max does not support float operations, although you said that "This can be efficiently computed in parallel using an atomic max". Could you please share with us how you handled this?
Thanks in advance!