MONAI
updates to Gaussian map for sliding window inference
This PR updates the calculation of the Gaussian weight map used during sliding window inference in "gaussian" mode.
The current version had multiple small issues:
- It computed the Gaussian weight map (image) via an nn.Conv1d sequence applied to an empty image (with a single 1 in the middle). We don't need to run any convolutions; it's much simpler to calculate the Gaussian map directly (it's also faster and uses less memory).
- For patch sizes with even dimensions (e.g. 128x128) it centered the Gaussian on patch_size//2, which is 0.5 pixel off-center (I'm not sure why we did it that way); see the small example after this list.
- Finally, the Gaussian 1D convolutions were done approximately (with the 'erf' internal approximation, truncated at 4 sigma). I'm not sure why we need any approximations here at all; it's trivial to compute the Gaussian weight map directly.
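As a small illustration of the centering point (my example, not part of the PR code):

```python
import torch

# for an even patch size, index patch_size // 2 is 0.5 pixel away from the true
# center of the patch, while a symmetric coordinate grid is centered exactly
n = 128
print(n // 2)                                   # 64, but the center of indices 0..127 is 63.5
coords = torch.arange(-(n - 1) / 2.0, (n - 1) / 2.0 + 1)
print(coords[0].item(), coords[-1].item())      # -63.5 63.5, symmetric about 0
```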
Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
```python
import torch
import time
from monai.networks.layers import GaussianFilter


def old_version(patch_size, approximate=True):
    # previous approach: put a single 1 in the middle of an empty patch-sized
    # image and blur it with GaussianFilter (a sequence of 1D convolutions)
    sigmas = [1.0, 2.0, 3.0]
    device = 'cuda:0'
    center_coords = [i // 2 for i in patch_size]
    importance_map = torch.zeros(patch_size, device=device)
    importance_map[tuple(center_coords)] = 1
    if not approximate:
        pt_gaussian = GaussianFilter(len(patch_size), sigmas, approx='sampled', truncated=10000).to(device=device, dtype=torch.float)
    else:
        pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(device=device, dtype=torch.float)
    importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
    importance_map = importance_map.squeeze(0).squeeze(0)
    importance_map = importance_map / torch.max(importance_map)
    importance_map = importance_map.float()
    return importance_map


def get_map_cpu(patch_size, dtype=torch.double):
    # direct approach: separable product of 1D gaussians, computed on CPU,
    # then moved to the GPU as float
    sigmas = [1.0, 2.0, 3.0]
    device = 'cuda:0'
    importance_map = 0
    for i in range(len(patch_size)):
        x = torch.arange(start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=dtype)
        x = torch.exp(x**2 / (-2 * sigmas[i] ** 2))  # 1D gaussian
        importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x
    importance_map = importance_map.to(device=device, dtype=torch.float)
    return importance_map


def get_map_cpu2(patch_size, dtype=torch.double):
    # variant: accumulate the exponents (summed over dimensions) and take a
    # single exp over the full volume at the end
    sigmas = [1.0, 2.0, 3.0]
    device = 'cuda:0'
    importance_map = 0
    for i in range(len(patch_size)):
        x = torch.arange(start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=dtype)
        x = x**2 / (-2 * sigmas[i] ** 2)  # 1D exponent
        # x = torch.exp(x**2 / (-2 * sigmas[i] ** 2))  # 1D gaussian
        importance_map = importance_map.unsqueeze(-1) + x[(None,) * i] if i > 0 else x
    importance_map = torch.exp(importance_map)
    importance_map = importance_map.to(device=device, dtype=torch.float)
    return importance_map


def get_map_gpu(patch_size):
    # direct approach, computed entirely on the GPU in float
    sigmas = [1.0, 2.0, 3.0]
    device = 'cuda:0'
    importance_map = 0
    for i in range(len(patch_size)):
        x = torch.arange(start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=torch.float, device=device)
        x = torch.exp(x**2 / (-2 * sigmas[i] ** 2))  # 1D gaussian
        importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x
    # importance_map = importance_map.to(device=device, dtype=torch.float)
    return importance_map


patch_size = [192, 192, 192]
n_iter = 10

# run once, pre-cache
x = old_version(patch_size)
x = get_map_cpu(patch_size)
x = get_map_cpu(patch_size, dtype=torch.float)
x = get_map_cpu2(patch_size)
x = get_map_cpu2(patch_size, dtype=torch.float)
x = get_map_gpu(patch_size)

tic = time.time()
for _ in range(n_iter): x = old_version(patch_size)
print('old way time', time.time() - tic, 'sec')

tic = time.time()
for _ in range(n_iter): x = get_map_cpu(patch_size)
print('new way time (double)', time.time() - tic, 'sec')

tic = time.time()
for _ in range(n_iter): x = get_map_cpu(patch_size, dtype=torch.float)
print('new way time (float)', time.time() - tic, 'sec')

tic = time.time()
for _ in range(n_iter): x = get_map_cpu2(patch_size)
print('new way (exp final) time (double)', time.time() - tic, 'sec')

tic = time.time()
for _ in range(n_iter): x = get_map_cpu2(patch_size, dtype=torch.float)
print('new way (exp final) time (float)', time.time() - tic, 'sec')

tic = time.time()
for _ in range(n_iter):
    x = get_map_gpu(patch_size)
torch.cuda.synchronize()
print('new way GPU time (float)', time.time() - tic, 'sec')
```
The script above is a timing test of the different ways to compute the Gaussian map for a 192x192x192 patch size.
Output (on V100 16gb):

```
old way time 0.12150430679321289 sec
new way time (double) 0.09786248207092285 sec
new way time (float) 0.04155373573303223 sec
new way (exp final) time (double) 0.17283344268798828 sec
new way (exp final) time (float) 0.09635710716247559 sec
new way GPU time (float) 0.0036644935607910156 sec
```
All new variants are faster, with less memory overhead, than the old version. The old version is also only approximate (and if done exactly it runs out of memory). Previously I calculated this map on CPU in double (for more accuracy) and moved it to GPU at the end (0.0097 sec), but computing it directly on GPU (in float) is much faster (0.0036 sec). I've updated the code.
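As a quick sanity check (my addition, reusing the functions from the script above; the 64^3 patch size is arbitrary), the fast GPU float map can be compared against the double-precision CPU map:

```python
# sanity-check sketch: the GPU float map should match the CPU double map
# up to float rounding (functions from the benchmark script above)
ref = get_map_cpu([64, 64, 64])    # computed in double on CPU, moved to GPU as float
fast = get_map_gpu([64, 64, 64])   # computed directly on GPU in float
print(torch.max(torch.abs(ref - fast)).item())  # expected to be ~1e-7 or smaller
```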
plz approve
For what it's worth, I have this code for defining a Gaussian kernel of arbitrary size:

```python
def gaussian_kernel(*shape, device, dtype, sigma=2.0):
    dims = [torch.arange(-s // 2 + s % 2, s - s // 2).to(device, dtype) for s in shape]
    grid = torch.stack(torch.meshgrid(*dims, indexing="ij"))
    return (grid**2).sum(0).div_(-2 * sigma**2).exp_()
```

I think this is the standard definition, though I don't think `gaussian_1d` produces an equivalent result.
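For reference, a quick way to try it (my example; the CPU device and float32 dtype are chosen just to keep it self-contained):

```python
import torch

# example call of the gaussian_kernel snippet above; device/dtype choices are mine
w = gaussian_kernel(128, 128, 128, device="cpu", dtype=torch.float32)
print(w.shape, w.max().item())  # torch.Size([128, 128, 128]), peak of 1.0 where the grid is zero
```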
Yeah, your version looks good too, but it doesn't handle different sigmas (per dimension). With meshgrid you're allocating 3 x DxWxH arrays (in the 3D patch case) and computing exp over the full DxWxH volume. My version (separable) only allocates DxWxH GPU memory and computes torch.exp three times on 1D arrays (D+W+H elements in total), so it's a bit faster and takes less memory. Otherwise it's the same.
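To put rough numbers on that (my back-of-the-envelope, assuming a 192^3 patch in float32):

```python
# rough memory estimate for a 192^3 float32 patch (illustrative only)
D = W = H = 192
b = 4  # bytes per float32
print('meshgrid stack  :', 3 * D * W * H * b / 2**20, 'MB')  # stacked (3, D, W, H) coordinates, ~81 MB
print('final weight map:', D * W * H * b / 2**20, 'MB')      # the (D, W, H) map itself, ~27 MB
print('separable 1D    :', (D + W + H) * b / 2**10, 'KB')    # three 1D gaussians, ~2 KB
```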
I hadn't thought of different sigmas, but if the last line is `(grid**2).div_(-2 * sigma**2).sum(0).exp_()` then `sigma` can be a tensor of values instead to handle that case.
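A minimal sketch of that variant (the function name, the default sigmas, and the `view` reshape below are my assumptions; the reshape makes `sigma` broadcast over the stacked coordinate axis rather than the last spatial axis):

```python
import torch

def gaussian_kernel_per_dim(*shape, device, dtype, sigma=(1.0, 2.0, 3.0)):
    # same construction as gaussian_kernel above, with one sigma per dimension
    dims = [torch.arange(-s // 2 + s % 2, s - s // 2).to(device, dtype) for s in shape]
    grid = torch.stack(torch.meshgrid(*dims, indexing="ij"))  # shape (ndim, *shape)
    sigma = torch.as_tensor(sigma, device=device, dtype=dtype).view(-1, *([1] * len(shape)))
    return (grid**2).div_(-2 * sigma**2).sum(0).exp_()

# e.g. gaussian_kernel_per_dim(128, 128, 128, device="cpu", dtype=torch.float32)
```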
/build
/build