
add an einsum_distloss implementation

Open Spark001 opened this issue 3 years ago • 4 comments

Spark001 avatar Jul 12 '22 12:07 Spark001

Update: here is a pure-PyTorch implementation of the distortion loss; no CUDA kernels needed.

import torch

def einsum_distloss(w, m, interval):
    '''
    Einsum realization of the distortion loss.
    There are B rays, each with N sampled points.
    w:        Float tensor of shape [B,N]. Volume rendering weight of each point.
    m:        Float tensor of shape [N].   Midpoint distance to the camera of each point.
    interval: Scalar or float tensor of shape [B,N]. The query interval of each point.
    Note:
    The first term of the distortion loss can be expressed as `(w @ mm @ w.T).diagonal()`,
    which can be computed more efficiently with `torch.einsum('bq,qp,bp->b', w, mm, w)`.
    '''
    mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()  # [N,N] pairwise midpoint distances
    loss = torch.einsum('bq,qp,bp->b', w, mm, w)    # pairwise (quadratic) term, per ray
    loss += (w * w * interval).sum(-1) / 3.         # intra-interval term
    return loss.mean()
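The equivalence claimed in the docstring is easy to check numerically. A minimal sanity check (not from the thread; shapes and values are made up) comparing the naive `(w @ mm @ w.T).diagonal()` form against the einsum contraction:

```python
import torch

torch.manual_seed(0)
B, N = 4, 16
w = torch.softmax(torch.randn(B, N), dim=-1)    # volume rendering weights, [B,N]
m = torch.linspace(0.0, 1.0, N)                 # shared midpoints per ray, [N]

mm = (m.unsqueeze(-1) - m.unsqueeze(-2)).abs()  # [N,N] pairwise distances
naive = (w @ mm @ w.T).diagonal()               # builds a [B,B] intermediate
fast = torch.einsum('bq,qp,bp->b', w, mm, w)    # contracts directly to [B]
assert torch.allclose(naive, fast, atol=1e-6)
```

The einsum version avoids the `[B,B]` intermediate of the matmul form, which is where its memory advantage over the naive expression comes from.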

Spark001 avatar Jul 14 '22 09:07 Spark001

  • Peak GPU memory (MB)

    | # of pts N           | 32  | 64  | 128  | 256  | 384 | 512 | 1024 |
    |----------------------|-----|-----|------|------|-----|-----|------|
    | original_distloss    | 102 | 396 | 1560 | 6192 | OOM | OOM | OOM  |
    | eff_distloss_native  | 12  | 24  | 48   | 96   | 144 | 192 | 384  |
    | eff_distloss         | 14  | 28  | 56   | 112  | 168 | 224 | 448  |
    | flatten_eff_distloss | 13  | 26  | 52   | 104  | 156 | 208 | 416  |
    | einsum_distloss      | 9   | 18  | 36   | 72   | 109 | 145 | 292  |

  • Run time accumulated over 100 runs (sec)

    | # of pts N           | 32  | 64  | 128 | 256  | 384 | 512 | 1024 |
    |----------------------|-----|-----|-----|------|-----|-----|------|
    | original_distloss    | 0.4 | 0.6 | 3.3 | 14.9 | OOM | OOM | OOM  |
    | eff_distloss_native  | 0.2 | 0.2 | 0.2 | 0.4  | 0.4 | 0.5 | 0.8  |
    | eff_distloss         | 0.2 | 0.2 | 0.2 | 0.3  | 0.5 | 0.6 | 0.9  |
    | flatten_eff_distloss | 0.2 | 0.2 | 0.2 | 0.3  | 0.5 | 0.5 | 0.8  |
    | einsum_distloss      | 0.1 | 0.1 | 0.1 | 0.2  | 0.3 | 0.4 | 0.7  |
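For context, a harness along these lines can reproduce such measurements. This is only a sketch, not the script that produced the numbers above; `benchmark` and its defaults are invented here. Peak memory is only meaningful on CUDA, so it falls back to NaN on CPU:

```python
import time
import torch

def benchmark(loss_fn, B=8192, N=64, runs=100, device='cpu'):
    """Time `loss_fn(w, m, interval)` over `runs` calls; report (sec, peak MB)."""
    torch.manual_seed(0)
    w = torch.softmax(torch.randn(B, N, device=device), dim=-1)
    m = torch.linspace(0.0, 1.0, N, device=device)
    interval = 1.0 / N
    if device == 'cuda':
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        loss_fn(w, m, interval)
    if device == 'cuda':
        torch.cuda.synchronize()
        peak_mb = torch.cuda.max_memory_allocated() / 2**20
    else:
        peak_mb = float('nan')  # CUDA memory stats unavailable on CPU
    return time.perf_counter() - start, peak_mb
```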

Spark001 avatar Jul 14 '22 09:07 Spark001

@Spark001: in your example, you are assuming that m is of size [N] and not [B, N], so identical for each ray. The allocation of mm is still O(N^2), and if you use [B, N] then you're back to the OOM of the original implementation, albeit with more samples. A workaround might be to have a custom CUDA kernel so that mm is never explicitly allocated but evaluated by each CUDA thread.
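Short of a custom CUDA kernel, the pairwise term can also be computed without ever allocating an [N,N] (or [B,N,N]) matrix via prefix sums, the trick behind the repository's linear-time `eff_distloss` variants. A sketch, assuming per-ray midpoints `m` of shape [B,N] sorted ascending along the sample axis (as depth-ordered samples are): for sorted midpoints, sum_{i,j} w_i w_j |m_i - m_j| = 2 * sum_i w_i (m_i * S_i - T_i), where S and T are exclusive prefix sums of w and w*m.

```python
import torch

def prefix_sum_distloss_quadratic(w, m):
    """Pairwise term of the distortion loss in O(B*N) memory.

    Assumes m is sorted ascending along the last axis. w, m: [B,N].
    """
    wm = w * m
    S = torch.cumsum(w, dim=-1) - w     # exclusive prefix sum of w
    T = torch.cumsum(wm, dim=-1) - wm   # exclusive prefix sum of w*m
    # sum_{i,j} w_i w_j |m_i - m_j| = 2 * sum_i w_i (m_i * S_i - T_i)
    return 2.0 * (w * (m * S - T)).sum(-1)
```

This handles per-ray midpoints without the quadratic allocation, at the cost of requiring sorted samples.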

bchretien avatar Jul 28 '22 18:07 bchretien

@bchretien Yes, you are right. My implementation assumes the sampling interval is identical for all rays, because uniform sampling is generally used in explicit voxel-based NeRF methods, e.g. DVGO and Plenoxels.

Spark001 avatar Aug 01 '22 08:08 Spark001