Pytorch-Correlation-extension
How to visualise the output
The output is essentially a 4D tensor for each batch element, so you would visualize it like any other 4D tensor, and there is no definitive answer. My best guess would be to pick a particular shift position and use a standard colormap to view the correlation map at that shift.
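To make "shift position" concrete, here is a naive NumPy sketch of what the correlation volume holds. It is an illustration, not the extension's CUDA kernel. It assumes kernel_size=1, stride=1, padding=0 and dilation_patch=1, and it assumes that patch index i along each patch axis corresponds to displacing input2 by i - patch_size // 2 pixels. The name `correlation_reference` is hypothetical. Under these assumptions the centre patch index is the zero shift, which is why the code below reads patch position (1, 1) as the no-shift map for patch_size=3.

```python
import numpy as np


def correlation_reference(input1, input2, patch_size):
    """Hypothetical naive reference for a correlation volume.

    Assumes kernel_size=1, stride=1, padding=0, dilation_patch=1, and that
    patch index i maps to a displacement of i - patch_size // 2 on input2.
    input1, input2: arrays of shape (B, C, H, W).
    Returns shape (B, patch_size, patch_size, H, W), where
    out[b, i, j, y, x] = sum_c input1[b, c, y, x] * input2[b, c, y + dy, x + dx]
    with dy = i - patch_size // 2, dx = j - patch_size // 2, and 0 wherever
    (y + dy, x + dx) falls outside input2.
    """
    B, C, H, W = input1.shape
    r = patch_size // 2
    out = np.zeros((B, patch_size, patch_size, H, W), dtype=input1.dtype)
    # Zero-pad input2 spatially by r so every displacement stays in bounds.
    padded = np.pad(input2, ((0, 0), (0, 0), (r, r), (r, r)))
    for i in range(patch_size):
        for j in range(patch_size):
            # padded[:, :, i + y, j + x] == input2[:, :, y + i - r, x + j - r]
            shifted = padded[:, :, i:i + H, j:j + W]
            # Dot product over the channel axis at every pixel.
            out[:, i, j] = (input1 * shifted).sum(axis=1)
    return out


rng = np.random.default_rng(0)
x1 = rng.integers(1, 4, size=(1, 1, 10, 10)).astype(np.float32)
x2 = rng.integers(1, 4, size=(1, 1, 10, 10)).astype(np.float32)
corr = correlation_reference(x1, x2, patch_size=3)
print(corr.shape)  # (1, 3, 3, 10, 10)

# The centre patch index (patch_size // 2 == 1) is the zero shift:
no_shift = corr[0, 1, 1]  # equals (x1 * x2).sum(axis=1)[0]
```

Edge pixels whose shifted position leaves the image are zero here, so maps at non-zero shifts show a dark border one pixel wide per unit of shift.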
import torch
from spatial_correlation_sampler import SpatialCorrelationSampler, spatial_correlation_sample
import matplotlib.pyplot as plt
device = "cuda"
batch_size = 1
channel = 1
H = 10
W = 10
dtype = torch.float32
input1 = torch.randint(1, 4, (batch_size, channel, H, W), dtype=dtype, device=device, requires_grad=True)
input2 = torch.randint_like(input1, 1, 4).requires_grad_(True)
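out = spatial_correlation_sample(input1, input2, patch_size=3)
# out is of shape [batch_size, patch_size, patch_size, H, W];
# index patch_size // 2 (here 1) along each patch axis is the zero shift.
# detach() is needed before numpy() because the inputs require grad.
no_shift = out[0, 1, 1].detach().cpu().numpy()
plt.imshow(no_shift)
plt.show()
# Patch indices (2, 2): a shift of (+1, +1) relative to the centre.
shift_11 = out[0, 2, 2].detach().cpu().numpy()
plt.imshow(shift_11)
plt.show()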
# And so on for the other shifts...
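Instead of one figure per shift, all P x P correlation maps can be tiled into a single image. The sketch below assumes a per-sample volume of shape (P, P, H, W), i.e. the output for one batch element converted to NumPy; `shift_mosaic` is a hypothetical helper, not part of spatial_correlation_sampler.

```python
import numpy as np


def shift_mosaic(corr):
    """Hypothetical helper: tile a (P, P, H, W) volume into one (P*H, P*W) image.

    Tile (i, j) of the result is corr[i, j], so the tile layout mirrors the
    patch layout: the centre tile is the zero-shift map.
    """
    p1, p2, h, w = corr.shape
    # (P, P, H, W) -> (P, H, P, W), then merge axes so that
    # mosaic[i * H + y, j * W + x] == corr[i, j, y, x].
    return corr.transpose(0, 2, 1, 3).reshape(p1 * h, p2 * w)


# Small synthetic volume: 3 x 3 shifts of 2 x 2 maps.
demo = np.arange(3 * 3 * 2 * 2).reshape(3, 3, 2, 2)
mosaic = shift_mosaic(demo)
print(mosaic.shape)  # (6, 6)

# With the `out` tensor from the snippet above (batch index 0):
#   import matplotlib.pyplot as plt
#   plt.imshow(shift_mosaic(out[0].detach().cpu().numpy()))
#   plt.show()
```

Because the whole mosaic goes through a single `imshow` call, every tile shares one colour scale, so different shifts can be compared directly. Separate `imshow` calls, as in the snippet above, rescale each map independently.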