Selective-Stereo
Selective-Stereo copied to clipboard
PyTorch Correlation Block
I see that the PyTorch correlation block is the same as the one in RAFT-Stereo, but the coordinate grids are defined differently here. As a result, using the PyTorch correlation block during inference raises an error.
I rewrote the PyTorch Correlation Block as follows:
class PytorchAlternateCorrBlock1D:
    """Horizontal (1-D) correlation lookup computed on the fly with ``grid_sample``.

    Instead of materialising an all-pairs correlation volume, every call
    samples ``fmap2`` at ``2 * radius + 1`` horizontal offsets around the
    query x-coordinate and takes the channel-wise dot product with ``fmap1``.
    One such lookup is done per pyramid level; between levels ``fmap2`` is
    average-pooled by a factor of two along the width only.

    Args:
        fmap1: 4-D feature map the sampled features are multiplied with.
            Assumed to be (B, D, H, W) like ``fmap2`` -- TODO confirm with caller.
        fmap2: 4-D feature map (B, D, H, W) that is sampled.
        num_levels: number of pyramid levels.
        radius: lookup radius in pixels at each level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        # Never filled by this class; kept so the attribute set is unchanged.
        self.corr_pyramid = []
        self.fmap1 = fmap1
        self.fmap2 = fmap2

    def corr(self, fmap1, fmap2, coords):
        """Correlate ``fmap1`` with ``fmap2`` sampled at ``coords``.

        Args:
            fmap1: feature map multiplied element-wise with the samples.
            fmap2: 4-D feature map (B, D, H, W) to sample from.
            coords: 5-D tensor whose last dim holds (x, y) pixel coordinates
                into ``fmap2`` and whose dim 3 enumerates the lookup offsets.

        Returns:
            Tensor with the lookup offsets on the last dim, scaled by
            ``1 / sqrt(D)``.
        """
        _, D, H, W = fmap2.shape
        # Map pixel coordinates to grid_sample's [-1, 1] range
        # (align_corners=True convention).
        xgrid, ygrid = coords.split([1, 1], dim=-1)
        xgrid = 2 * xgrid / (W - 1) - 1
        # max(..., 1) avoids 0/0 -> NaN for a single-row map; with
        # align_corners=True any y then resolves to row 0.
        ygrid = 2 * ygrid / max(H - 1, 1) - 1
        grid = torch.cat([xgrid, ygrid], dim=-1)

        # One grid_sample per lookup offset keeps peak memory low.
        output_corr = [
            torch.sum(
                F.grid_sample(fmap2, grid_slice, align_corners=True) * fmap1,
                dim=1,
            )
            for grid_slice in grid.unbind(3)
        ]
        corr = torch.stack(output_corr, dim=1).permute(0, 2, 3, 1)
        return corr / (D ** 0.5)

    def __call__(self, coords):
        """Look up correlation features around the given x-coordinates.

        Args:
            coords: 4-D tensor with channels on dim 1. Only channel 0 is
                used; it is read as the x pixel coordinate into ``fmap2``
                at full resolution.

        Returns:
            float32 tensor of shape
            (B, num_levels * (2 * radius + 1), H, W).
        """
        r = self.radius
        # Only the x channel is needed; slicing also accepts (x, y) coords.
        coords = coords[:, :1].permute(0, 2, 3, 1)
        _, h1, _, _ = coords.shape
        fmap1 = self.fmap1
        fmap2 = self.fmap2

        # Both tensors are level-independent, so build them once.
        dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        dx = dx.view(1, 1, 1, -1)
        # Each query pixel must sample fmap2 on its OWN row. A constant
        # y = 0 is only valid for a volume flattened to height 1; here
        # fmap2 keeps its full height, so y = 0 would read row 0 everywhere.
        # Assumes fmap2 has the same height as coords -- TODO confirm.
        rows = torch.arange(h1, device=coords.device, dtype=torch.float32)
        rows = rows.view(1, h1, 1, 1)

        out_pyramid = []
        for i in range(self.num_levels):
            # fmap2's width is halved per level, so x shrinks by 2**i;
            # the height is never pooled, so y stays as is.
            x0 = dx + coords / 2 ** i
            y0 = rows.to(x0.dtype).expand_as(x0)
            coords_lvl = torch.stack([x0, y0], dim=-1)
            out_pyramid.append(self.corr(fmap1, fmap2, coords_lvl))
            # Skip the pooling whose result would never be used.
            if i + 1 < self.num_levels:
                fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()
But the output is not the same as the CorrBlock1D. Could you please check if the above Correlation Block is correct?
I think if you want to use this block, you should use the same block during training. There may be indexing problems due to the different indexing code.