Implementation questions
Hi, thank you for your great work.
While studying the implementation, I noticed some differences between the code and the paper, and I would like to ask two questions.

First, in Section 3.2 of the main paper, the key and value for the near-view cross-attention appear to be computed separately for each sampled point in 3D space, since the notation $[f_1^s,\cdots,f_{F-1}^s]$ implies that $F-1$ features are concatenated per point. In the implementation, however, the concatenated context covers the features of all sampled points at once. Is my understanding correct?

Second, I am unable to locate where the Ray self-attention is computed. In `class FeatureAggregator(nn.Module)`, the weighted summation is performed immediately after the near-view cross-attention. Could you point me to where the Ray self-attention is computed?
```python
class FeatureAggregator(nn.Module):
    ...
    query = rearrange(query, "bfn 1 s c -> bfn s c")
    context = rearrange(context, "bfn k s c -> bfn (k s) c")
    t_emb = repeat(t_emb, "bf c -> (bf n) c", n=h * w)
    out = self.T1(query, t=t_emb, context=context, attention_mask=mask)  # Near-view cross attention
    mask = (
        reduce(
            mask,
            "bfn k s -> bfn s 1",
            "sum",
        )
        > 0
    )
    weight = self.t1_point_fc(out)  # Weighted summation
    weight = weight.masked_fill(~mask, torch.finfo(weight.dtype).min)
    weight = F.softmax(weight, dim=1)
    t1 = torch.sum(out * weight, dim=1)
    images = rearrange(t1, "(bf h w) c -> bf c h w", bf=bf, h=h, w=w)
```
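For clarity, this is the kind of Ray self-attention I was expecting to find somewhere in the forward pass, i.e. self-attention over the $s$ samples along each ray (a toy sketch with made-up shapes; `nn.MultiheadAttention` only stands in for whatever attention block the paper uses):

```python
import torch
import torch.nn as nn

bfn, s, c = 4, 6, 32  # rays, samples per ray, channels (toy sizes)
out = torch.randn(bfn, s, c)  # per-sample features after near-view cross attention

# Self-attention where the sequence dimension is the s samples along one ray,
# so every sampled point can exchange information with the others on its ray.
ray_attn = nn.MultiheadAttention(embed_dim=c, num_heads=4, batch_first=True)
refined, _ = ray_attn(out, out, out)

print(refined.shape)  # torch.Size([4, 6, 32])
```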
I have the same question about the Ray self-attention and would also appreciate an answer.