3DGNN_pytorch
3DGNN_pytorch copied to clipboard
About the `batch_index` part in `models.py`
First of all, thanks for your code.
https://github.com/yanx27/3DGNN_pytorch/blob/b5e0188e56f926cff9b2c9a68bedf42cb3b42d2f/models.py#L411
Can we use convolution (e.g., 1×1) operations to implement this part?
I tried the following:
# 1x1 conv + BatchNorm + PReLU over the stacked K*C neighbor channels.
# NOTE(review): assumes self.k is the neighbor count K — confirm against __init__.
mixed_channels = 2048 * self.k
self.g_rnn_conv = nn.Sequential(
    nn.Conv2d(mixed_channels, mixed_channels, 1),
    nn.BatchNorm2d(mixed_channels),
    nn.PReLU(),
)
...
# loop over timestamps to unroll
# NOTE(review): indentation was flattened when this snippet was pasted; in the
# real file the statements below are the body of this `for` loop.
for i in range(self.gnn_iternum):
# do this for every sample in batch, not nice, but I don't know
# how to use index_select batchwise
# fetch features from nearest neighbors
# N H W K*C
h = h.view(N * (H * W), C) # NHW C
# NOTE(review): `knn` must hold flat indices into the (N*H*W, C) tensor, i.e.
# per-sample neighbor indices already offset by sample position in the batch —
# otherwise neighbors leak across samples; confirm how `knn` is built.
neighbor_f = torch.index_select(h, 0, knn).view(N, H, W, K * C)
# to channels-first (N, K*C, H, W) so the 1x1 conv can mix neighbor features
neighbor_f = neighbor_f.permute(0, 3, 1, 2)
neighbor_f = self.g_rnn_conv(neighbor_f)
neighbor_f = neighbor_f.permute(0, 2, 3, 1).contiguous() # N H W KC
# split the K*C channel axis back into (K, C); K is the major axis, matching
# the (N, H, W, K*C) view of the index_select result above
neighbor_f = neighbor_f.view(N, H * W, K, C)
# aggregate and iterate messages in m, keep original CNN features h for later
m = torch.mean(neighbor_f, dim=2)
h = h.view(N, (H * W), C)
# concatenate current state with messages
concat = torch.cat((h, m), 2) # N HW 2C
# get new features by running MLP q and activation function
h = self.q_rnn_actf(self.q_rnn_layer(concat)) # N HW C
And here is another attempt:
# Two 1x1 conv blocks: `g` mixes the K*C stacked neighbor channels in place,
# `q` maps the concatenated (state, message) 4096 channels back down to 2048.
# NOTE(review): assumes self.k is the neighbor count K — confirm against __init__.
g_channels = 2048 * self.k
self.g_rnn_conv = nn.Sequential(
    nn.Conv2d(g_channels, g_channels, 1),
    nn.BatchNorm2d(g_channels),
    nn.PReLU(),
)
self.q_rnn_conv = nn.Sequential(
    nn.Conv2d(4096, 2048, 1),
    nn.BatchNorm2d(2048),
    nn.PReLU(),
)
...
# get k nearest neighbors
knn = self.__get_knn_indices(proj_3d) # N HW K
# flatten to a 1-D index vector for index_select below.
# NOTE(review): for index_select into an (N*H*W, C) tensor these indices must
# already be offset per batch sample — confirm __get_knn_indices does this.
knn = knn.view(N * H * W * K).long() # NHWK
# prepare CNN encoded features for RNN
h = cnn_encoder_output # N C H W
# After permuting dimensions, .contiguous() is generally required before .view()
# (view needs contiguous memory).
h = h.permute(0, 2, 3, 1).contiguous() # N H W C
# loop over timestamps to unroll
# NOTE(review): indentation was flattened when this snippet was pasted; in the
# real file the statements below are the body of this `for` loop.
for i in range(self.gnn_iternum):
# do this for every sample in batch, not nice, but I don't know
# how to use index_select batchwise
# fetch features from nearest neighbors
# N H W K*C
h = h.view(N * (H * W), C) # NHW C
neighbor_f = torch.index_select(h, 0, knn).view(N, H, W, K * C)
neighbor_f = neighbor_f.permute(0, 3, 1, 2) # N KC H W
neighbor_f = self.g_rnn_conv(neighbor_f)
neighbor_f = neighbor_f.permute(0, 2, 3, 1).contiguous() # N H W KC
# split the K*C channel axis back into (K, C); K is the major axis, matching
# the (N, H, W, K*C) view of the index_select result above
neighbor_f = neighbor_f.view(N, H * W, K, C)
# aggregate and iterate messages in m, keep original CNN features h for later
m = torch.mean(neighbor_f, dim=2)
h = h.view(N, (H * W), C)
# concatenate current state with messages, back to channels-first for the conv
concat = torch.cat((h, m), 2).view(N, H, W, 2 * C) # N HW 2C
concat = concat.permute(0, 3, 1, 2)
# get new features by running MLP q and activation function
# (here the MLP is replaced by the 1x1 conv block q_rnn_conv)
h = self.q_rnn_conv(concat) # N, C, H, W
h = h.permute(0, 2, 3, 1).contiguous() # N H W C
# format RNN activations back to image, concatenate original CNN embedding, return
h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous() # N C H W
output = self.output_conv(
torch.cat((cnn_encoder_output, h), 1)) # N 2C H W
return output