Selective-Stereo icon indicating copy to clipboard operation
Selective-Stereo copied to clipboard

PyTorch Correlation Block

Open naik24 opened this issue 10 months ago • 2 comments

I see that the PyTorch correlation block is the same as the one in RAFT-Stereo, but the coordinate grids are defined differently here; as a result, using the PyTorch correlation block during inference raises an error.

naik24 avatar Feb 03 '25 21:02 naik24

I rewrote the PyTorch Correlation Block as follows:

class PytorchAlternateCorrBlock1D:
    """Compute 1D (horizontal) correlation features on the fly.

    No correlation volume is precomputed: on every call, ``fmap2`` is
    bilinearly sampled at ``2 * radius + 1`` horizontal offsets around each
    query x-coordinate and dotted with ``fmap1``. This is repeated for
    ``num_levels`` pyramid levels, where ``fmap2`` is average-pooled by 2
    along its width between levels.

    Args:
        fmap1: Reference feature map, unpacked as ``(B, D, H, W)``.
        fmap2: Target feature map with the same layout; it is the tensor
            that gets sampled.
        num_levels: Number of pyramid levels.
        radius: Half-width of the horizontal lookup window.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        # Never filled by this class; kept so the attribute set is unchanged.
        self.corr_pyramid = []
        self.fmap1 = fmap1
        self.fmap2 = fmap2

    def corr(self, fmap1, fmap2, coords):
        """Correlate ``fmap1`` with ``fmap2`` sampled at ``coords``.

        Args:
            fmap1: ``(B, D, H, W)`` reference features.
            fmap2: ``(B, D, H, W2)`` features to sample from.
            coords: ``(B, H, W, K, 2)`` pixel coordinates; the last axis is
                ``(x, y)`` in the pixel grid of ``fmap2``.

        Returns:
            ``(B, H, W, K)`` dot products scaled by ``1 / sqrt(D)``.
        """
        B, D, H, W = fmap2.shape
        # Map pixel coordinates to the [-1, 1] range grid_sample expects
        # (align_corners=True convention).
        # NOTE(review): W == 1 still divides by zero here, e.g. when a pooled
        # level collapses to a single column -- confirm that cannot happen.
        xgrid, ygrid = coords.split([1, 1], dim=-1)
        xgrid = 2 * xgrid / (W - 1) - 1
        if H > 1:
            ygrid = 2 * ygrid / (H - 1) - 1
        else:
            # A single row: avoid 0/0 -> NaN; with align_corners=True any
            # finite y resolves to row 0.
            ygrid = torch.zeros_like(ygrid)

        grid = torch.cat([xgrid, ygrid], dim=-1)
        # One grid_sample per window offset keeps peak memory at (B, D, H, W).
        output_corr = [
            torch.sum(
                F.grid_sample(fmap2, grid_slice, align_corners=True) * fmap1,
                dim=1,
            )
            for grid_slice in grid.unbind(3)
        ]
        corr = torch.stack(output_corr, dim=1).permute(0, 2, 3, 1)

        return corr / D ** 0.5

    def __call__(self, coords):
        """Look up correlation features around the given x-coordinates.

        Args:
            coords: ``(B, 1, H, W)`` tensor of x-coordinates (in level-0
                pixels of ``fmap2``) to centre each lookup window on.

        Returns:
            ``(B, num_levels * (2 * radius + 1), H, W)`` float32 tensor;
            channels are ordered level-major, then by offset ``-r .. r``.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        fmap1 = self.fmap1
        fmap2 = self.fmap2
        device = coords.device

        # Loop-invariant pieces: window offsets, window centres, row indices.
        dx = torch.linspace(-r, r, 2 * r + 1, device=device).view(1, 1, 1, -1)
        centre_x = coords.reshape(batch, h1, w1, 1)
        rows = torch.arange(h1, device=device).view(1, h1, 1, 1)

        out_pyramid = []
        for i in range(self.num_levels):
            x0 = dx + centre_x / 2 ** i
            # fmap2 keeps its full height (pooling is along width only), so
            # each pixel must sample its own row. A zero y would be normalised
            # to -1 in corr() and read the top row for every pixel.
            y0 = rows.to(x0.dtype).expand_as(x0)
            coords_lvl = torch.stack([x0, y0], dim=-1)
            corr = self.corr(fmap1, fmap2, coords_lvl)
            fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
            out_pyramid.append(corr)
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

But the output is not the same as the CorrBlock1D. Could you please check if the above Correlation Block is correct?

Image

naik24 avatar Feb 04 '25 07:02 naik24

I think if you want to use this block, you should use the same block during training. There may be indexing problems because the two blocks use different indexing code.

Windsrain avatar Feb 17 '25 06:02 Windsrain