
About the `batch_index` part of `models.py`

Open lartpang opened this issue 5 years ago • 0 comments

First of all, thanks for your code...

https://github.com/yanx27/3DGNN_pytorch/blob/b5e0188e56f926cff9b2c9a68bedf42cb3b42d2f/models.py#L411

Can we use convolution (e.g. 1x1) operations to implement this part instead?

I gave it a try like this:

        self.g_rnn_conv = nn.Sequential(
            nn.Conv2d(2048 * self.k, 2048 * self.k, 1),
            nn.BatchNorm2d(2048 * self.k),
            nn.PReLU()
        )

        ...

        # loop over timestamps to unroll
        for i in range(self.gnn_iternum):
            # fetch the K nearest-neighbor features for every location with a
            # single index_select over the batch-flattened feature matrix
            h = h.view(N * (H * W), C)  # NHW C
            neighbor_f = torch.index_select(h, 0, knn).view(N, H, W, K * C)
            neighbor_f = neighbor_f.permute(0, 3, 1, 2)
            neighbor_f = self.g_rnn_conv(neighbor_f)
            neighbor_f = neighbor_f.permute(0, 2, 3, 1).contiguous()  # N H W KC
            neighbor_f = neighbor_f.view(N, H * W, K, C)

            # aggregate and iterate messages in m, keep original CNN features h for later
            m = torch.mean(neighbor_f, dim=2)
            h = h.view(N, (H * W), C)

            # concatenate current state with messages
            concat = torch.cat((h, m), 2)  # N HW 2C

            # get new features by running MLP q and activation function
            h = self.q_rnn_actf(self.q_rnn_layer(concat))  # N HW C
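On the "index_select batchwise" comment: `torch.gather` can fetch neighbor features per sample without flattening indices across the whole batch. A minimal sketch with toy shapes (the sizes here are hypothetical, and the index layout assumes `knn` holds per-sample indices in `[0, HW)`):

```python
import torch

# toy shapes: batch N, locations HW, neighbors K, channels C
N, HW, K, C = 2, 6, 3, 4
h = torch.randn(N, HW, C)                # per-sample features, N HW C
knn = torch.randint(0, HW, (N, HW, K))   # per-sample neighbor indices, N HW K

# expand indices to (N, HW*K, C) so gather picks whole feature vectors
idx = knn.view(N, HW * K, 1).expand(-1, -1, C)
neighbor_f = torch.gather(h, 1, idx).view(N, HW, K, C)  # N HW K C

# aggregate messages, as in the loop body above
m = neighbor_f.mean(dim=2)               # N HW C
print(m.shape)  # torch.Size([2, 6, 4])
```

This avoids building batch-offset indices into the `(N*H*W, C)` matrix, since `gather` indexes each sample independently along dim 1.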

And another try:

        self.g_rnn_conv = nn.Sequential(
            nn.Conv2d(2048 * self.k, 2048 * self.k, 1),
            nn.BatchNorm2d(2048 * self.k),
            nn.PReLU()
        )

        self.q_rnn_conv = nn.Sequential(
            nn.Conv2d(4096, 2048, 1),
            nn.BatchNorm2d(2048),
            nn.PReLU()
        )

...

        # get k nearest neighbors
        knn = self.__get_knn_indices(proj_3d)  # N HW K
        knn = knn.view(N * H * W * K).long()  # NHWK

        # prepare CNN encoded features for RNN
        h = cnn_encoder_output  # N C H W
        # after permuting dimensions, contiguous() is generally required before view()
        h = h.permute(0, 2, 3, 1).contiguous()  # N H W C

        # loop over timestamps to unroll
        for i in range(self.gnn_iternum):
            # fetch the K nearest-neighbor features for every location with a
            # single index_select over the batch-flattened feature matrix
            h = h.view(N * (H * W), C)  # NHW C
            neighbor_f = torch.index_select(h, 0, knn).view(N, H, W, K * C)
            neighbor_f = neighbor_f.permute(0, 3, 1, 2)  # N KC H W
            neighbor_f = self.g_rnn_conv(neighbor_f)
            neighbor_f = neighbor_f.permute(0, 2, 3, 1).contiguous()  # N H W KC
            neighbor_f = neighbor_f.view(N, H * W, K, C)

            # aggregate and iterate messages in m, keep original CNN features h for later
            m = torch.mean(neighbor_f, dim=2)
            h = h.view(N, (H * W), C)

            # concatenate current state with messages
            concat = torch.cat((h, m), 2).view(N, H, W, 2 * C)  # N H W 2C
            concat = concat.permute(0, 3, 1, 2)

            # get new features by running MLP q and activation function
            h = self.q_rnn_conv(concat)  # N, C, H, W
            h = h.permute(0, 2, 3, 1).contiguous()  # N H W C

        # format RNN activations back to image, concatenate original CNN embedding, return
        h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous()  # N C H W
        output = self.output_conv(
            torch.cat((cnn_encoder_output, h), 1))  # N 2C H W

        return output

This is my code...
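For what it's worth, replacing the MLP `q` with a 1x1 conv should be mathematically equivalent (up to the extra BatchNorm), since a 1x1 `Conv2d` is the same per-pixel linear map as an `nn.Linear`. A small sketch with toy shapes (all sizes hypothetical) to check this:

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 8, 4, 4
conv = nn.Conv2d(2 * C, C, kernel_size=1)
lin = nn.Linear(2 * C, C)

# copy the conv weights into the linear layer so outputs can be compared
with torch.no_grad():
    lin.weight.copy_(conv.weight.view(C, 2 * C))
    lin.bias.copy_(conv.bias)

x = torch.randn(N, 2 * C, H, W)
y_conv = conv(x)                                            # N C H W
y_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)      # N C H W
print(torch.allclose(y_conv, y_lin, atol=1e-5))  # True
```

So the difference between your two tries is really just where the normalization and activation sit, not the linear transform itself.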

lartpang avatar Apr 01 '19 12:04 lartpang