updates to Gaussian map for sliding window inference

Open myron opened this issue 2 years ago • 4 comments

This PR updates the calculation of the Gaussian weight map used during sliding window inference with mode="gaussian".

The current version had several small issues:

  • it computed the Gaussian weight map (image) via a sequence of nn.Conv1d calls applied to an empty image (with a single 1 in the middle). We don't need to run any convolutions; it's much simpler to calculate the Gaussian map directly (it's also faster and takes less memory)
  • for patch sizes with even dimensions (e.g. 128x128), it centered the Gaussian on patch_size // 2, which is 0.5 pixel off-center (I'm not sure why we did that; see the sketch after this list)
  • finally, the Gaussian 1D convolutions were done approximately (with the 'erf' internal approximation, truncated at sigma=4). I'm not sure why we need any approximations here at all; it's trivial to compute the Gaussian weight map directly
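
To make the half-pixel point concrete, here is a minimal sketch (128 is just an example size; the symmetric coordinate construction matches the updated code below):

import torch

patch_size = 128
old_center = patch_size // 2            # 64: where the old code placed the Gaussian peak
true_center = (patch_size - 1) / 2.0    # 63.5: the symmetric center of indices 0..127

# coordinates built the new way are exactly symmetric around zero
x = torch.arange(start=-(patch_size - 1) / 2.0, end=(patch_size - 1) / 2.0 + 1)
print(old_center, true_center, len(x), x[0].item(), x[-1].item())  # 64 63.5 128 -63.5 63.5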

Types of changes

  • [x] Non-breaking change (fix or new feature that would not break existing functionality).
  • [ ] Breaking change (fix or new feature that would cause existing functionality to change).
  • [ ] New tests added to cover the changes.
  • [ ] Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • [ ] Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • [ ] In-line docstrings updated.
  • [ ] Documentation updated, tested make html command in the docs/ folder.

myron avatar Oct 09 '22 01:10 myron

import torch
import time
from monai.networks.layers import GaussianFilter


def old_version(patch_size, approximate=True):
    # old approach: apply GaussianFilter (1D convolutions) to a one-hot image
    sigmas = [1.0, 2.0, 3.0]
    device = 'cuda:0'
    center_coords = [i // 2 for i in patch_size]

    importance_map = torch.zeros(patch_size, device=device)
    importance_map[tuple(center_coords)] = 1
    if not approximate:
        pt_gaussian = GaussianFilter(len(patch_size), sigmas, approx='sampled', truncated=10000).to(device=device, dtype=torch.float)
    else:
        pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(device=device, dtype=torch.float)

    importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
    importance_map = importance_map.squeeze(0).squeeze(0)
    importance_map = importance_map / torch.max(importance_map)
    importance_map = importance_map.float()

    return importance_map

        
def get_map_cpu(patch_size, dtype=torch.double):
    # new approach: build the separable Gaussian directly on CPU, then move it to GPU
    sigmas = [1.0, 2.0, 3.0]
    device = 'cuda:0'

    importance_map = 0
    for i in range(len(patch_size)):
        x = torch.arange(start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=dtype)
        x = torch.exp(x**2 / (-2 * sigmas[i] ** 2))  # 1D gaussian
        importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x
    importance_map = importance_map.to(device=device, dtype=torch.float)
    return importance_map

def get_map_cpu2(patch_size, dtype=torch.double):
    # variant: accumulate the exponent (log space) and apply torch.exp once at the end
    sigmas = [1.0, 2.0, 3.0]
    device = 'cuda:0'

    importance_map = 0
    for i in range(len(patch_size)):
        x = torch.arange(start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=dtype)
        x = x**2 / (-2 * sigmas[i] ** 2)  # 1D log-gaussian (exponent only)
        importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x
    importance_map = torch.exp(importance_map)
    importance_map = importance_map.to(device=device, dtype=torch.float)
    return importance_map

def get_map_gpu(patch_size):
    # fastest variant: build the separable Gaussian directly on the GPU in float
    sigmas = [1.0, 2.0, 3.0]
    device = 'cuda:0'

    importance_map = 0
    for i in range(len(patch_size)):
        x = torch.arange(start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=torch.float, device=device)
        x = torch.exp(x**2 / (-2 * sigmas[i] ** 2))  # 1D gaussian
        importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x
    return importance_map  # already on the GPU in float; no transfer needed



patch_size = [192, 192, 192]
n_iter = 10

# run once, pre-cache
x = old_version(patch_size)
x = get_map_cpu(patch_size)
x = get_map_cpu(patch_size, dtype=torch.float)
x = get_map_cpu2(patch_size)
x = get_map_cpu2(patch_size, dtype=torch.float)
x = get_map_gpu(patch_size)

tic = time.time()
for _ in range(n_iter): x = old_version(patch_size)
print('old way time', time.time()-tic, 'sec')

tic = time.time()
for _ in range(n_iter): x = get_map_cpu(patch_size)
print('new way time (double)', time.time()-tic, 'sec')


tic = time.time()
for _ in range(n_iter): x = get_map_cpu(patch_size, dtype=torch.float)
print('new way time (float)', time.time()-tic, 'sec')

tic = time.time()
for _ in range(n_iter): x = get_map_cpu2(patch_size)
print('new way (exp final) time (double)', time.time()-tic, 'sec')


tic = time.time()
for _ in range(n_iter): x = get_map_cpu2(patch_size, dtype=torch.float)
print('new way (exp final) time (float)', time.time()-tic, 'sec')

tic = time.time()
for _ in range(n_iter): 
    x = get_map_gpu(patch_size)
    torch.cuda.synchronize()
print('new way GPU time (float)', time.time()-tic, 'sec')

Here is a timing test for the different ways to compute the Gaussian map for a 192x192x192 patch size.

Output (on a V100 16GB):

old way time 0.12150430679321289 sec
new way time (double) 0.09786248207092285 sec
new way time (float) 0.04155373573303223 sec
new way (exp final) time (double) 0.17283344268798828 sec
new way (exp final) time (float) 0.09635710716247559 sec
new way GPU time (float) 0.0036644935607910156 sec

All new variants are faster and have less memory overhead than the old version. The old version is also only approximate (and if done exactly, it runs out of memory). Previously I calculated this map on CPU in double (for more accuracy) and moved it to GPU at the end (0.0097 sec), but computing it on GPU (in float) is much faster (0.0036 sec). I've updated the code.
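
For a rough sanity check of that accuracy tradeoff, something like this sketch (reusing patch_size, get_map_cpu, and get_map_gpu from the script above) compares the double-accumulated map against the float one:

m_double = get_map_cpu(patch_size)   # accumulated in double on CPU, moved to GPU as float
m_float = get_map_gpu(patch_size)    # computed directly on the GPU in float
print('max abs difference:', (m_double - m_float).abs().max().item())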

plz approve

myron avatar Oct 09 '22 20:10 myron

For what it's worth, I have this code for defining a Gaussian kernel of arbitrary size:

def gaussian_kernel(*shape, device, dtype, sigma=2.0):
    dims = [torch.arange(-s // 2 + s % 2, s - s // 2).to(device, dtype) for s in shape]
    grid = torch.stack(torch.meshgrid(*dims, indexing="ij"))
    return (grid**2).sum(0).div_(-2 * sigma**2).exp_()

I think this is the standard definition, though I don't think gaussian_1d produces an equivalent result.

ericspod avatar Oct 09 '22 21:10 ericspod

yeah, your version looks good too. But it doesn't handle a different sigma per dimension. With meshgrid you're allocating a 3xDxWxH array (in the case of a 3D patch) and computing exp over the full 3D volume (DxWxH). My version (separable) only allocates DxWxH of GPU memory and computes torch.exp over just D+W+H elements (one 1D array per axis), so it's a bit faster and takes less memory. Otherwise it's the same.
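
For reference, one rough way to check the memory claim is a peak-memory sketch like this (it assumes get_map_gpu and gaussian_kernel from the comments above are in scope; numbers will vary by device and allocator state):

import torch

def peak_mem_mib(fn, *args, **kwargs):
    # measure peak GPU memory used by a single call to fn
    torch.cuda.reset_peak_memory_stats()
    fn(*args, **kwargs)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

print('separable:', peak_mem_mib(get_map_gpu, [192, 192, 192]), 'MiB')
print('meshgrid: ', peak_mem_mib(gaussian_kernel, 192, 192, 192,
                                 device='cuda:0', dtype=torch.float32), 'MiB')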

myron avatar Oct 09 '22 23:10 myron

I hadn't thought of different sigmas, but if the last line is (grid**2).div_(-2 * sigma**2).sum(0).exp_(), then sigma can be a tensor of per-axis values instead to handle that case (reshaped so it broadcasts over the leading axis dimension, e.g. sigma.view(-1, 1, 1, 1)).
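
Concretely, a minimal sketch of that tensor-sigma variant (gaussian_kernel_nd is a hypothetical name for this illustration; the view call is what makes the per-axis broadcast line up):

import torch

def gaussian_kernel_nd(*shape, device, dtype, sigmas=(1.0, 2.0, 3.0)):
    # len(sigmas) must match len(shape): one sigma per axis
    dims = [torch.arange(-s // 2 + s % 2, s - s // 2).to(device, dtype) for s in shape]
    grid = torch.stack(torch.meshgrid(*dims, indexing="ij"))  # (ndim, *shape)
    sigma = torch.as_tensor(sigmas, device=device, dtype=dtype)
    sigma = sigma.view(-1, *([1] * len(shape)))               # (ndim, 1, 1, ...) for broadcasting
    return (grid**2).div_(-2 * sigma**2).sum(0).exp_()

k = gaussian_kernel_nd(4, 5, 6, device="cpu", dtype=torch.float64)
print(k.shape)  # torch.Size([4, 5, 6])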

ericspod avatar Oct 09 '22 23:10 ericspod

/build

wyli avatar Oct 12 '22 08:10 wyli

/build

wyli avatar Oct 12 '22 08:10 wyli