
Update ANDMask Problem

Open · khan-yin opened this issue · 1 comment

Hello, author. I want to add ANDMask to the benchmark, but I ran into a problem when running it on the LSA64 dataset. Could you please check whether my ANDMask code is correct and how to fix the error on LSA64? The other datasets I tested did not raise this error.

Code

import torch
import torch.nn.functional as F
from torch import autograd


class ANDMask(ERM):
    """
    Learning Explanations that are Hard to Vary [https://arxiv.org/abs/2009.00329]
    AND-Mask implementation from [https://github.com/gibipara92/learning-explanations-hard-to-vary]
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super(ANDMask, self).__init__(model, dataset, optimizer, hparams)

        # Hyperparameters
        self.tau = self.hparams['tau']

    def mask_grads(self, tau, gradients, params):
        """Zero out gradient components whose sign does not agree across environments."""
        for param, grads in zip(params, gradients):
            grads = torch.stack(grads, dim=0)    # (n_envs, *param.shape)
            grad_signs = torch.sign(grads)
            # Keep a component only where the mean sign across environments reaches tau
            mask = torch.mean(grad_signs, dim=0).abs() >= tau
            mask = mask.to(torch.float32)
            avg_grad = torch.mean(grads, dim=0)

            # Rescale by the fraction of surviving components
            mask_t = mask.sum() / mask.numel()
            param.grad = mask * avg_grad
            param.grad *= 1. / (1e-10 + mask_t)

    def update(self):
        X, Y = self.dataset.get_next_batch()

        out, out_features = self.predict(X)
        n_domains = self.dataset.get_nb_training_domains()
        out, labels = self.dataset.split_tensor_by_domains(out, Y, n_domains)

        # Compute the loss for each environment
        env_losses = torch.zeros(out.shape[0]).to(self.device)
        for i in range(out.shape[0]):
            for t_idx in range(out.shape[2]):     # Number of time steps
                env_losses[i] += F.cross_entropy(out[i, :, t_idx, :], labels[i, :, t_idx])

        # Compute the gradients of each environment's loss separately
        param_gradients = [[] for _ in self.model.parameters()]
        for env_loss in env_losses:
            env_grads = autograd.grad(env_loss, self.model.parameters(), retain_graph=True)
            for grads, env_grad in zip(param_gradients, env_grads):
                grads.append(env_grad)

        # Mask the per-environment gradients and take an optimizer step
        self.optimizer.zero_grad()
        self.mask_grads(self.tau, param_gradients, self.model.parameters())
        self.optimizer.step()
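
For reference, the masking logic itself seems to check out on toy gradients. Below is a minimal standalone sketch (the tensors are made up for illustration, not from WOODS): with two environments and tau = 1.0, only the components whose gradient sign agrees in both environments keep a nonzero averaged gradient.

import torch

tau = 1.0
# Two environments, one parameter of shape (3,)
env_grads = [torch.tensor([1.0,  2.0, -1.0]),
             torch.tensor([0.5, -2.0, -3.0])]

grads = torch.stack(env_grads, dim=0)   # (n_envs, 3)
grad_signs = torch.sign(grads)
mask = (torch.mean(grad_signs, dim=0).abs() >= tau).to(torch.float32)
avg_grad = torch.mean(grads, dim=0)

masked = mask * avg_grad
masked *= 1. / (1e-10 + mask.sum() / mask.numel())
print(mask)    # tensor([1., 0., 1.]) -> signs agree only on components 0 and 2
print(masked)  # rescaled average gradient on the agreeing components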

Error for LSA64

[screenshot of the LSA64 error traceback]
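
The traceback points into the per-time-step loss loop, so a quick shape check before that loop may help localize the issue. This is a sketch; the expected layout of split_tensor_by_domains is my assumption, inferred from how update() indexes out and labels:

# Hypothetical debug lines, inserted just before the env_losses loop in update()
print("out shape:", out.shape)        # assumed (n_domains, batch_per_domain, n_time_steps, n_classes)
print("labels shape:", labels.shape)  # assumed (n_domains, batch_per_domain, n_time_steps)
# If LSA64 yields a different layout (e.g. no time dimension),
# out[i, :, t_idx, :] and labels[i, :, t_idx] will raise an indexing error
assert out.dim() == 4 and labels.dim() == 3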

khan-yin · May 04 '24 15:05