ScoreCAM icon indicating copy to clipboard operation
ScoreCAM copied to clipboard

ScoreCAM fix batch proposal

Open FrancescoMandru opened this issue 2 years ago • 0 comments

I found that the ScoreCAM code does not support batched inputs. Here is my attempt to fix it:

    def forward(self, x):
        """Compute Score-CAM saliency maps for a (possibly batched) input.

        Each sample's predicted class is the one explained. ``self.model(x)``
        is expected to populate ``self.values.activations`` with the target
        layer's feature maps (presumably via a forward hook registered
        elsewhere, shape (B, C, h, w) -- TODO confirm against the caller).

        Args:
            x: input batch of shape (B, channels, H, W).

        Returns:
            A tuple ``(cam, idx)``. ``cam`` is a CPU tensor of shape
            (B, 1, H, W); each sample's map is min-max normalised to [0, 1]
            independently (an all-zero map stays zero). ``idx`` is the
            predicted class: an ``int`` when B == 1 (same as the single-image
            API) or a list of ``int`` (one per sample) otherwise.
        """
        with torch.no_grad():
            B, _, H, W = x.size()
            device = x.device
            self.model.zero_grad()
            score = self.model(x)
            prob = torch.nn.functional.softmax(score, dim=1)
            _, idx = torch.max(prob, dim=1)  # (B,) predicted class per sample

            # Put activation maps through ReLU, because the values are not
            # normalized with eq.(1) without it. Clone to CPU now: every
            # masked forward pass below overwrites self.values.activations.
            acts = torch.nn.functional.relu(self.values.activations).to('cpu').clone()
            acts = torch.nn.functional.interpolate(acts, (H, W), mode='bilinear')
            C = acts.size(1)

            # Min-max normalise each (sample, channel) map on its own;
            # constant maps get a denominator of 1 to avoid division by zero.
            act_min = acts.amin(dim=(2, 3), keepdim=True)
            act_range = acts.amax(dim=(2, 3), keepdim=True) - act_min
            act_range = torch.where(act_range != 0, act_range, torch.ones_like(act_range))
            self.activations = (acts - act_min) / act_range

            # Weight each map by the softmax probability of each sample's
            # predicted class when the input is masked by that map.
            rows = torch.arange(B, device=device)
            probs = []
            for i in range(C):
                mask = self.activations[:, i:i + 1, :, :].to(device)  # (B, 1, H, W)
                masked_prob = torch.nn.functional.softmax(self.model(x * mask), dim=1)
                probs.append(masked_prob[rows, idx].to('cpu'))

            # Stack along dim=1 so row b holds sample b's C channel weights.
            weights = torch.stack(probs, dim=1).view(B, C, 1, 1)

            cam = (weights * self.activations).sum(1, keepdim=True)
            cam = torch.nn.functional.relu(cam)
            # Per-sample normalisation to [0, 1]; guard all-zero maps from NaN.
            cam = cam - cam.amin(dim=(2, 3), keepdim=True)
            cam_max = cam.amax(dim=(2, 3), keepdim=True)
            cam = cam / torch.where(cam_max > 0, cam_max, torch.ones_like(cam_max))

        # Keep the int return for single images; .item() fails for B > 1.
        idx_out = idx.item() if B == 1 else idx.tolist()
        return cam.data, idx_out

Let me know if there is something wrong with my solution.

FrancescoMandru avatar Oct 13 '22 13:10 FrancescoMandru