torch_efficient_distloss
add an einsum_distloss implementation
Update: this provides a Python-interface implementation of the distortion loss; no CUDA kernels are needed.
import torch

def einsum_distloss(w, m, interval):
    '''
    Einsum realization of distortion loss.
    There are B rays each with N sampled points.
    w:        Float tensor in shape [B,N]. Volume rendering weights of each point.
    m:        Float tensor in shape [N]. Midpoint distance to camera of each point
              (shared by all rays).
    interval: Scalar or float tensor in shape [B,N]. The query interval of each point.
    Note:
    The first term of the distortion loss can be expressed as `(w @ mm @ w.T).diagonal()`,
    which can be further accelerated with `torch.einsum('bq, qp, bp->b', w, mm, w)`.
    '''
    # Pairwise absolute midpoint distances |m_q - m_p|.
    mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()  # [N,N]
    # First term: sum_{q,p} w_q * w_p * |m_q - m_p| for each ray.
    loss = torch.einsum('bq, qp, bp->b', w, mm, w)  # [B]
    # Second term: (1/3) * sum_q w_q^2 * interval_q for each ray.
    loss += (w*w*interval).sum(-1)/3.
    return loss.mean()
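As a sanity check, the einsum form can be compared against an explicit double sum over sample pairs. The following is a minimal sketch, not part of the proposed change: `naive_distloss`, the sizes `B=4, N=16`, and the uniform samples on [0, 1] are assumptions chosen only for illustration.

```python
import torch

def einsum_distloss(w, m, interval):
    # Same computation as the function above, repeated so this check runs on its own.
    mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()       # [N,N]
    loss = torch.einsum('bq,qp,bp->b', w, mm, w)          # [B]
    loss = loss + (w * w * interval).sum(-1) / 3.
    return loss.mean()

def naive_distloss(w, m, interval):
    # Hypothetical reference: explicit double loop over all sample pairs (slow).
    B, N = w.shape
    total = torch.zeros(B, dtype=w.dtype)
    for i in range(N):
        for j in range(N):
            total = total + w[:, i] * w[:, j] * (m[i] - m[j]).abs()
    total = total + (w * w * interval).sum(-1) / 3.
    return total.mean()

torch.manual_seed(0)
B, N = 4, 16                                   # illustrative sizes only
w = torch.rand(B, N, dtype=torch.float64)
w = w / w.sum(-1, keepdim=True)                # weights sum to 1 on each ray
s = torch.linspace(0, 1, N + 1, dtype=torch.float64)
m = 0.5 * (s[1:] + s[:-1])                     # midpoints shared by all rays, [N]
interval = 1.0 / N

loss_einsum = einsum_distloss(w, m, interval)
loss_naive = naive_distloss(w, m, interval)
print(torch.allclose(loss_einsum, loss_naive))
```

The double loop is only useful for testing on small inputs; it is far slower than the einsum call.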
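- Peak GPU memory (MB)

| # of pts N | 32 | 64 | 128 | 256 | 384 | 512 | 1024 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| original_distloss | 102 | 396 | 1560 | 6192 | OOM | OOM | OOM |
| eff_distloss_native | 12 | 24 | 48 | 96 | 144 | 192 | 384 |
| eff_distloss | 14 | 28 | 56 | 112 | 168 | 224 | 448 |
| flatten_eff_distloss | 13 | 26 | 52 | 104 | 156 | 208 | 416 |
| einsum_distloss | 9 | 18 | 36 | 72 | 109 | 145 | 292 |

- Run time accumulated over 100 runs (sec)

| # of pts N | 32 | 64 | 128 | 256 | 384 | 512 | 1024 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| original_distloss | 0.4 | 0.6 | 3.3 | 14.9 | OOM | OOM | OOM |
| eff_distloss_native | 0.2 | 0.2 | 0.2 | 0.4 | 0.4 | 0.5 | 0.8 |
| eff_distloss | 0.2 | 0.2 | 0.2 | 0.3 | 0.5 | 0.6 | 0.9 |
| flatten_eff_distloss | 0.2 | 0.2 | 0.2 | 0.3 | 0.5 | 0.5 | 0.8 |
| einsum_distloss | 0.1 | 0.1 | 0.1 | 0.2 | 0.3 | 0.4 | 0.7 |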
@Spark001: your example assumes that m has shape [N] rather than [B, N], i.e. that it is identical for every ray. The allocation of mm is still O(N^2), and if m has shape [B, N], then mm becomes [B, N, N] and you are back to the OOM of the original implementation, albeit with more samples. A workaround might be a custom CUDA kernel in which mm is never explicitly allocated but is evaluated on the fly by each CUDA thread.
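For per-ray midpoints, a different route from the custom kernel suggested above is a prefix-sum (cumsum) formulation, which never materializes the pairwise matrix. The sketch below assumes that m has shape [B, N] and is sorted in ascending order along each ray; `prefix_sum_distloss` and `dense_distloss` are hypothetical names used only for this comparison, and this is not claimed to match any implementation benchmarked in the tables.

```python
import torch

def prefix_sum_distloss(w, m, interval):
    # Assumption: w, m are [B,N] and m is sorted ascending along the last dim.
    # sum_{i,j} w_i w_j |m_i - m_j| = 2 * sum_i [ w_i m_i * sum_{j<i} w_j
    #                                            - w_i * sum_{j<i} w_j m_j ]
    wm = w * m
    w_cum = w.cumsum(-1)      # running sum of weights
    wm_cum = wm.cumsum(-1)    # running sum of weight * midpoint
    pair = wm[..., 1:] * w_cum[..., :-1] - w[..., 1:] * wm_cum[..., :-1]
    loss = 2 * pair.sum(-1) + (w * w * interval).sum(-1) / 3.
    return loss.mean()

def dense_distloss(w, m, interval):
    # Reference with an explicit [B,N,N] matrix; O(B*N^2) memory.
    mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()
    loss = torch.einsum('bq,bqp,bp->b', w, mm, w)
    loss = loss + (w * w * interval).sum(-1) / 3.
    return loss.mean()

torch.manual_seed(0)
B, N = 4, 16                                              # illustrative sizes only
w = torch.rand(B, N, dtype=torch.float64)
w = w / w.sum(-1, keepdim=True)
m = torch.rand(B, N, dtype=torch.float64).sort(-1).values  # per-ray, ascending
interval = torch.rand(B, N, dtype=torch.float64) / N

loss_prefix = prefix_sum_distloss(w, m, interval)
loss_dense = dense_distloss(w, m, interval)
print(torch.allclose(loss_prefix, loss_dense))
```

The intermediate tensors here are all [B, N], so memory grows linearly in N rather than quadratically.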
@bchretien Yes, you are right. My implementation assumes that the sampling interval is identical for all rays, because uniform sampling is generally used in explicit voxel-based NeRF methods, e.g. DVGO and Plenoxels.
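Under that uniform-sampling assumption, the shared inputs could be built roughly as follows. This is a sketch only: the bounds `near` and `far` and the sample count `N` are hypothetical values, and neither DVGO nor Plenoxels is claimed to construct its samples exactly this way.

```python
import torch

near, far, N = 0.0, 1.0, 8                 # hypothetical ray bounds and sample count

s = torch.linspace(near, far, N + 1)       # N + 1 uniformly spaced interval edges
m = 0.5 * (s[1:] + s[:-1])                 # [N] midpoints, identical for every ray
interval = (far - near) / N                # scalar interval length

print(m.shape, interval)
```

With these, `m` has the [N] shape and `interval` the scalar form that `einsum_distloss` expects.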