BalancingGroups
Need help understanding the logic behind the construction of the GroupDRO class
Inside the Python script models.py, there is a class named GroupDRO, shown in the code below:
```python
class GroupDRO(ERM):
    def __init__(self, hparams, dataset):
        super(GroupDRO, self).__init__(hparams, dataset)
        # One weight per (class, group) pair
        self.register_buffer(
            "q", torch.ones(self.n_classes * self.n_groups).cuda())

    def groups_(self, y, g):
        # Map each (label, group) pair in the batch to a flat index,
        # and build a boolean mask over the batch for each index present
        idx_g, idx_b = [], []
        all_g = y * self.n_groups + g
        for g in all_g.unique():
            idx_g.append(g)
            idx_b.append(all_g == g)
        return zip(idx_g, idx_b)

    def compute_loss_value_(self, i, x, y, g, epoch):
        losses = self.loss(self.network(x), y)
        # Multiplicative (exponentiated-gradient) update of the weights q,
        # followed by renormalization
        for idx_g, idx_b in self.groups_(y, g):
            self.q[idx_g] *= (
                self.hparams["eta"] * losses[idx_b].mean()).exp().item()
        self.q /= self.q.sum()
        # Final loss: q-weighted average of the per-group mean losses
        loss_value = 0
        for idx_g, idx_b in self.groups_(y, g):
            loss_value += self.q[idx_g] * losses[idx_b].mean()
        return loss_value
```
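For concreteness, here is a minimal standalone sketch (not from the repository) of the indexing trick used in `groups_`, assuming two classes and two groups for illustration:

```python
import torch

n_groups = 2  # assumed value, for illustration only

# A batch of four examples with labels y and group annotations g
y = torch.tensor([0, 0, 1, 1])
g = torch.tensor([0, 1, 0, 1])

# Flatten each (class, group) pair into a single index, as groups_ does
all_g = y * n_groups + g
print(all_g)       # tensor([0, 1, 2, 3]): each pair gets its own index

# Boolean mask selecting the examples belonging to one flat group
print(all_g == 2)  # tensor([False, False,  True, False])
```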
The concept of GroupDRO was introduced in the paper "Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization". In that paper, the loss is defined as the worst-group risk:

$$\hat{\mathcal{L}}_{\mathrm{DRO}}(\theta) \;=\; \max_{g \in \mathcal{G}} \; \mathbb{E}_{(x,y) \sim \hat{P}_g}\!\left[\ell(\theta; (x,y))\right]$$
However, the loss computed by your GroupDRO class looks very different from the loss function defined above. Please help me understand why you have redefined the loss function this way.
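For comparison, a direct translation of the worst-group objective above might look like the following sketch (hypothetical, not from the repository; it assumes the same class context and reuses the `groups_` helper):

```python
import torch

# Hypothetical method for comparison only: take the maximum of the
# per-group mean losses, instead of the q-weighted sum used in
# compute_loss_value_ above.
def worst_group_loss(self, x, y, g):
    losses = self.loss(self.network(x), y)
    group_means = [losses[idx_b].mean()
                   for _, idx_b in self.groups_(y, g)]
    return torch.stack(group_means).max()
```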