Stratified-Transformer
Stratified-Transformer copied to clipboard
Understanding the function get_indice_pairs(p2v_map, counts, new_p2v_map, new_counts, downsample_idx, batch, xyz, window_size, i)
Your work is great! However, I have some questions about the code. Could you explain in detail what this function does, in particular its inputs and outputs, the logic behind it, and why it is designed this way? Thank you so much!
The function is here:
def get_indice_pairs(p2v_map, counts, new_p2v_map, new_counts, downsample_idx, batch, xyz, window_size, i):
    """Return two aligned flat index tensors enumerating point pairs.

    The result is the concatenation of two groups of pairs:

    1. Every ordered pair ``(a, b)`` of valid entries taken from the same row
       of ``p2v_map`` (self-pairs included).
    2. For every row of ``new_p2v_map``, pairs ``(a, b)`` where ``a`` is any
       valid entry of the row, ``b`` is a valid entry that also appears in
       ``downsample_idx``, and ``a`` and ``b`` have different quantised
       coordinates ``floor((xyz - xyz.min(0)) / window_size)`` (shifted by
       half of ``window_size`` when ``i`` is odd).

    Args:
        p2v_map: ``[n, k]`` integer tensor of indices; row ``r`` holds
            ``counts[r]`` valid entries followed by padding.
            (presumably point indices grouped per small window -- confirm
            against the caller)
        counts: ``[n]`` number of valid leading entries in each row of
            ``p2v_map``.
        new_p2v_map: ``[n', k']`` integer tensor with the same layout as
            ``p2v_map``; it is used to index both ``xyz`` and a ``[N]`` mask.
            (presumably a coarser grouping -- confirm against the caller)
        new_counts: ``[n']`` number of valid leading entries in each row of
            ``new_p2v_map``.
        downsample_idx: indices (cast to ``long`` here) of the entries that
            may appear on the second side of a group-2 pair.
        batch: ``[N]`` tensor; only its shape and device are used, as the
            template for the ``[N]`` boolean mask.
        xyz: coordinates indexed by ``new_p2v_map``; ``[N, 3]`` according to
            the shape notes in the original code.
        window_size: cell size used to quantise ``xyz``.
        i: integer; its parity selects the unshifted (even) or half-cell
            shifted (odd) quantisation grid.

    Returns:
        ``(index_0, index_1)``: two 1-D tensors of equal length, where
        ``index_0[m]`` and ``index_1[m]`` form the m-th pair. Group-1 pairs
        come first, then group-2 pairs.
    """
    # ---- group 1: all valid pairs inside each row of p2v_map --------------
    n, k = p2v_map.shape
    # Build the slot-validity mask on the inputs' own device instead of
    # hard-coding `.cuda()`, so the function also runs on CPU and on
    # non-default GPUs.
    slots = torch.arange(k, device=counts.device).unsqueeze(0)
    mask = slots < counts.unsqueeze(-1)  # [n, k]
    mask_mat = mask.unsqueeze(-1) & mask.unsqueeze(-2)  # [n, k, k]
    index_0 = p2v_map.unsqueeze(-1).expand(-1, -1, k)[mask_mat]  # [M, ]
    index_1 = p2v_map.unsqueeze(1).expand(-1, k, -1)[mask_mat]  # [M, ]

    # ---- group 2: (any valid entry, down-sampled entry) per new row -------
    downsample_mask = torch.zeros_like(batch).bool()  # [N, ]
    downsample_mask[downsample_idx.long()] = True
    downsample_mask = downsample_mask[new_p2v_map]  # [n', k']

    new_n, new_k = new_p2v_map.shape
    new_slots = torch.arange(new_k, device=new_counts.device).unsqueeze(0)
    new_mask = new_slots < new_counts.unsqueeze(-1)  # [n', k']
    # Padding slots must never count as down-sampled entries.
    downsample_mask = downsample_mask & new_mask
    # First axis of a pair: any valid entry; second axis: valid AND sampled.
    new_mask_mat = new_mask.unsqueeze(-1) & downsample_mask.unsqueeze(-2)  # [n', k', k']

    xyz_min = xyz.min(0)[0]
    if i % 2 == 0:
        window_coord = (xyz[new_p2v_map] - xyz_min) // window_size  # [n', k', 3]
    else:
        # Odd i: shift the grid by half a cell before quantising.
        window_coord = (xyz[new_p2v_map] + 1 / 2 * window_size - xyz_min) // window_size  # [n', k', 3]

    # Keep only pairs whose quantised coordinates differ on at least one
    # axis; pairs with identical coordinates are dropped from this group.
    mask_mat_prev = (window_coord.unsqueeze(2) != window_coord.unsqueeze(1)).any(-1)  # [n', k', k']
    new_mask_mat = new_mask_mat & mask_mat_prev  # [n', k', k']

    new_index_0 = new_p2v_map.unsqueeze(-1).expand(-1, -1, new_k)[new_mask_mat]  # [M', ]
    new_index_1 = new_p2v_map.unsqueeze(1).expand(-1, new_k, -1)[new_mask_mat]  # [M', ]

    index_0 = torch.cat([index_0, new_index_0], 0)
    index_1 = torch.cat([index_1, new_index_1], 0)
    return index_0, index_1