WOODS
WOODS copied to clipboard
Update ANDMask Problem
Hello, author. I want to add ANDMask to the benchmark, but I ran into a problem when running it on the LSA64 dataset. Could you please check whether the ANDMask code is correct and how to fix it for LSA64? The other datasets I tested did not produce this error.
Code
class ANDMask(ERM):
    """
    Learning Explanations that are Hard to Vary [https://arxiv.org/abs/2009.00329]
    AND-Mask implementation from [https://github.com/gibipara92/learning-explanations-hard-to-vary]
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super(ANDMask, self).__init__(model, dataset, optimizer, hparams)

        # Agreement threshold: the minimum |mean of per-environment gradient
        # signs| a parameter component needs in order to be updated.
        self.tau = self.hparams['tau']

    def mask_grads(self, tau, gradients, params):
        """Apply the AND-mask and write the masked average gradient into
        each parameter's ``.grad``.

        Args:
            tau: agreement threshold in [0, 1].
            gradients: one list of per-environment gradient tensors for
                each parameter, aligned with ``params``.
            params: iterable of model parameters.
        """
        for param, grads in zip(params, gradients):
            grads = torch.stack(grads, dim=0)  # (n_envs, *param.shape)
            grad_signs = torch.sign(grads)
            # BUG FIX: use the `tau` argument instead of silently ignoring it
            # in favor of self.tau (they coincide at the only call site in
            # update(), so default behavior is unchanged, but the parameter
            # was dead code before).
            mask = torch.mean(grad_signs, dim=0).abs() >= tau
            mask = mask.to(torch.float32)
            avg_grad = torch.mean(grads, dim=0)
            # Rescale surviving components so the overall gradient magnitude
            # is preserved; 1e-10 guards against an all-zero mask.
            mask_t = (mask.sum() / mask.numel())
            param.grad = mask * avg_grad
            param.grad *= (1. / (1e-10 + mask_t))

    def update(self):
        """One optimization step: per-environment losses, per-environment
        gradients, AND-masked aggregation, then an optimizer step."""
        X, Y = self.dataset.get_next_batch()
        out, out_features = self.predict(X)
        n_domains = self.dataset.get_nb_training_domains()
        out, labels = self.dataset.split_tensor_by_domains(out, Y, n_domains)

        # Compute loss for each environment; `out` is indexed as
        # (domain, batch, time, class) — assumed from the indexing below,
        # TODO confirm against the dataset's split_tensor_by_domains.
        env_losses = torch.zeros(out.shape[0]).to(self.device)
        for i in range(out.shape[0]):
            for t_idx in range(out.shape[2]):  # Number of time steps
                env_losses[i] += F.cross_entropy(out[i, :, t_idx, :], labels[i, :, t_idx])

        # Compute gradients for each env.
        params = list(self.model.parameters())
        param_gradients = [[] for _ in params]
        for env_loss in env_losses:
            # BUG FIX: allow_unused=True. Without it, autograd.grad raises
            # "One of the differentiated Tensors appears to not have been
            # used in the graph" on models where some parameters do not
            # contribute to a given environment's loss — the likely cause of
            # the dataset-specific (LSA64) crash. Unused parameters get a
            # zero gradient, which leaves all other datasets' behavior
            # unchanged.
            env_grads = autograd.grad(env_loss, params,
                                      retain_graph=True, allow_unused=True)
            for param, grads, env_grad in zip(params, param_gradients, env_grads):
                if env_grad is None:
                    env_grad = torch.zeros_like(param)
                grads.append(env_grad)

        # Back propagate
        self.optimizer.zero_grad()
        self.mask_grads(self.tau, param_gradients, params)
        self.optimizer.step()