pytorch-mutual-information
Mutual Information in PyTorch
Update: Integrated into Kornia
Batch computation of mutual information and histogram2d in PyTorch
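This implementation uses kernel density estimation (KDE) with a Gaussian kernel to calculate histograms and joint histograms. We use a diagonal bandwidth matrix for the multivariate case, which allows us to decompose the multivariate kernel into the product of univariate kernels. From Wikipedia, the multivariate kernel density estimate is
$$\hat{f}_{\mathbf{H}}(\mathbf{x}) = \frac{1}{n}\sum_{i=1}^{n} K_{\mathbf{H}}(\mathbf{x} - \mathbf{x}_i), \qquad K_{\mathbf{H}}(\mathbf{x}) = |\mathbf{H}|^{-1/2}\, K\big(\mathbf{H}^{-1/2}\mathbf{x}\big),$$
where the bandwidth matrix $\mathbf{H}$ is symmetric and positive definite. With a diagonal $\mathbf{H} = \operatorname{diag}(h_1^2, \dots, h_d^2)$ and a standard Gaussian $K$, the kernel factorizes as
$$K_{\mathbf{H}}(\mathbf{x}) = \prod_{j=1}^{d} \frac{1}{h_j}\,\phi\big(x_j / h_j\big),$$
where $\phi$ is the standard normal density.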
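To make the diagonal-bandwidth factorization concrete, here is a small numerical check, assuming a standard Gaussian kernel K. It evaluates the multivariate kernel with H = diag(h_1², ..., h_d²) directly as |H|^{-1/2} K(H^{-1/2} x) and compares it with the product of scaled 1D Gaussians. The helper names `gaussian_1d` and `gaussian_diag_nd` are hypothetical and exist only for this illustration; they are not part of this repo.

```python
import math
import torch

def gaussian_1d(u: torch.Tensor, h: float) -> torch.Tensor:
    # Scaled univariate Gaussian kernel: (1 / h) * phi(u / h)
    return torch.exp(-0.5 * (u / h) ** 2) / (h * math.sqrt(2 * math.pi))

def gaussian_diag_nd(x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    # Multivariate Gaussian kernel K_H(x) = |H|^{-1/2} K(H^{-1/2} x)
    # for a diagonal H = diag(h_1^2, ..., h_d^2)
    d = x.shape[-1]
    z = x / h                    # H^{-1/2} x, elementwise because H is diagonal
    det_h = torch.prod(h ** 2)   # |H|
    norm = (2 * math.pi) ** (-d / 2) / torch.sqrt(det_h)
    return norm * torch.exp(-0.5 * (z ** 2).sum(dim=-1))

x = torch.tensor([0.3, -1.2], dtype=torch.float64)
h = torch.tensor([0.5, 0.8], dtype=torch.float64)

joint = gaussian_diag_nd(x, h)                                  # 2D kernel, evaluated directly
product = gaussian_1d(x[0], 0.5) * gaussian_1d(x[1], 0.8)      # product of 1D kernels
print(torch.allclose(joint, product))  # True
```

Because the 2D kernel splits into per-axis factors, a joint histogram only needs the per-axis kernel values of each sample, which is what makes a batched implementation cheap to write.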
Example usage
Setup
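import torch
from PIL import Image
from torchvision import transforms
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
img1 = Image.open('grad1.jpg').convert('L')  # load as 8-bit grayscale
img2 = Image.open('grad.jpg').convert('L')
# ToTensor() gives a float tensor of shape (1, H, W) with values in [0, 1]; add a batch dim
img1 = transforms.ToTensor()(img1).unsqueeze(dim=0).to(device)
img2 = transforms.ToTensor()(img2).unsqueeze(dim=0).to(device)
# Pair of different images, pair of same images
input1 = torch.cat([img2, img2])
input2 = torch.cat([img1, img2])
B, C, H, W = input1.shape  # shape: (2, 1, 300, 300)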
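Histogram usage (sigma is the Gaussian kernel bandwidth; the [0, 1] pixel values from ToTensor() are rescaled to the 0-255 range of the bin centers, and the bins are created on the same device as the input):
hist = histogram(input1.view(B, H*W) * 255, torch.linspace(0, 255, 256, device=device), sigma)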
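The internals of `histogram` are not shown here. The sketch below illustrates the KDE idea described above for the 1D case: every sample contributes a Gaussian weight to every bin center, and the weights are averaged over samples. The name `kde_histogram`, the normalization so each row sums to 1, and the `eps` term are assumptions for illustration; the repo's or Kornia's exact implementation may differ.

```python
import torch

def kde_histogram(x: torch.Tensor, bins: torch.Tensor, sigma: float,
                  eps: float = 1e-10) -> torch.Tensor:
    """Gaussian-KDE 'soft' histogram (hypothetical helper).

    x:    (B, N) samples, e.g. flattened pixel values
    bins: (K,) bin centers on the same scale and device as x
    returns (B, K) weights that sum to 1 over the bins
    """
    residuals = x.unsqueeze(-1) - bins                   # (B, N, K): sample-to-bin distances
    kernel = torch.exp(-0.5 * (residuals / sigma) ** 2)  # unnormalized Gaussian weights
    pdf = kernel.mean(dim=1)                             # (B, K): average over samples
    return pdf / (pdf.sum(dim=1, keepdim=True) + eps)    # normalize to a distribution

x = torch.tensor([[0.0, 1.0, 1.0, 2.0]])
bins = torch.linspace(0, 2, 3)                 # bin centers 0, 1, 2
hist = kde_histogram(x, bins, sigma=0.1)       # approximately [[0.25, 0.50, 0.25]]
```

Unlike a hard `torch.histc`, this histogram is differentiable with respect to `x`. The price is the (B, N, K) intermediate tensor: two 300 x 300 images with 256 bins already need about 46 million floats.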
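Histogram 2D usage (both inputs rescaled from [0, 1] to the 0-255 bin range; bins created on the input's device):
hist = histogram2d(input1.view(B, H*W) * 255, input2.view(B, H*W) * 255, torch.linspace(0, 255, 256, device=device), sigma)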
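Because the bandwidth matrix is diagonal, the joint kernel for a pixel pair is the product of the two per-axis kernels. Summing that product over all pixels is a batched matrix product. The sketch below shows this under the same assumptions as a plain KDE histogram. `kde_histogram2d` is a hypothetical name, not the repo's `histogram2d`.

```python
import torch

def kde_histogram2d(x1: torch.Tensor, x2: torch.Tensor, bins: torch.Tensor,
                    sigma: float, eps: float = 1e-10) -> torch.Tensor:
    """Gaussian-KDE joint histogram (hypothetical helper).

    x1, x2: (B, N) paired samples (same pixel positions in both images)
    bins:   (K,) bin centers shared by both axes
    returns (B, K, K) joint distribution; each (K, K) slice sums to 1
    """
    # Per-axis Gaussian weights of every sample for every bin: (B, N, K)
    k1 = torch.exp(-0.5 * ((x1.unsqueeze(-1) - bins) / sigma) ** 2)
    k2 = torch.exp(-0.5 * ((x2.unsqueeze(-1) - bins) / sigma) ** 2)
    # Diagonal bandwidth => 2D kernel is k1[b, n, i] * k2[b, n, j];
    # summing over samples n is the batched product (B, K, N) @ (B, N, K)
    joint = torch.bmm(k1.transpose(1, 2), k2)  # (B, K, K)
    return joint / (joint.sum(dim=(1, 2), keepdim=True) + eps)

# Identical inputs put all the mass on the diagonal
x = torch.tensor([[0.0, 1.0, 2.0]])
joint = kde_histogram2d(x, x, torch.linspace(0, 2, 3), sigma=0.1)
```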
Mutual Information (of images)
MI = MutualInformation(num_bins=256, sigma=0.4, normalize=True).to(device)
score = MI(input1, input2)
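Once a joint distribution is available, mutual information follows from entropies: I(X;Y) = H(X) + H(Y) - H(X,Y). The sketch below computes this in bits for a batch of joint histograms such as those produced by `histogram2d`. The helper name `mutual_information_from_joint` is hypothetical. Reading `normalize=True` as the normalization 2I / (H(X) + H(Y)), which lies in [0, 1], is an assumption. The `MutualInformation` module may use a different definition.

```python
import torch

def mutual_information_from_joint(p_xy: torch.Tensor, normalize: bool = False,
                                  eps: float = 1e-10) -> torch.Tensor:
    """p_xy: (B, K, K) joint distribution per batch element (each slice sums to 1).

    Returns a (B,) tensor of mutual information values in bits.
    """
    p_x = p_xy.sum(dim=2)  # (B, K) marginal of the first input
    p_y = p_xy.sum(dim=1)  # (B, K) marginal of the second input
    # eps avoids log(0); empty bins then contribute 0 * log2(eps) = 0
    h_x = -(p_x * torch.log2(p_x + eps)).sum(dim=1)
    h_y = -(p_y * torch.log2(p_y + eps)).sum(dim=1)
    h_xy = -(p_xy * torch.log2(p_xy + eps)).sum(dim=(1, 2))
    mi = h_x + h_y - h_xy
    if normalize:
        # Assumed normalization; eps guards against two constant images
        mi = 2 * mi / (h_x + h_y + eps)
    return mi

same = torch.tensor([[[0.5, 0.0], [0.0, 0.5]]])  # Y is a copy of X
indep = torch.full((1, 2, 2), 0.25)              # X and Y independent
print(mutual_information_from_joint(same))       # about 1 bit
print(mutual_information_from_joint(indep))      # about 0
```

The pair of identical images in the setup corresponds to the first case, where all of the joint mass lies on the diagonal. The pair of different images gives a more spread-out joint histogram and a lower score.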
Results
Histogram
Joint Histogram