Pytorch-PCGrad
reduction is always 'mean'
Setting reduction='sum' has no effect because of this check:
    if self._reduction:
        merged_grad[shared] = torch.stack([g[shared]
                                           for g in pc_grad]).mean(dim=0)
self._reduction holds a string, and any non-empty string is truthy in Python, so this branch runs and the mean is taken even when reduction is 'sum'.
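A minimal sketch of the problem and one possible fix, using plain Python lists in place of tensors so it runs without PyTorch. The function names `buggy_reduce` and `fixed_reduce` are hypothetical, and the `elif` branch in the buggy version is assumed for illustration; only the `if self._reduction:` check is taken from the code quoted above.

```python
def buggy_reduce(values, reduction):
    # Mirrors the reported check: any non-empty string is truthy,
    # so this branch is taken for 'mean' and 'sum' alike.
    if reduction:
        return sum(values) / len(values)
    elif reduction == 'sum':  # assumed branch; never reached for 'sum'
        return sum(values)
    raise ValueError('invalid reduction method')


def fixed_reduce(values, reduction):
    # Compare against the expected string explicitly.
    if reduction == 'mean':
        return sum(values) / len(values)
    elif reduction == 'sum':
        return sum(values)
    raise ValueError(f'invalid reduction method: {reduction!r}')


grads = [1.0, 3.0]
print(buggy_reduce(grads, 'sum'))  # mean is returned despite 'sum'
print(fixed_reduce(grads, 'sum'))  # sum is returned
```

In the actual class, the equivalent change would presumably be to test `self._reduction == 'mean'` instead of the bare `self._reduction`.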
A fork with this issue fixed is available here:
https://github.com/anzeyimana/Pytorch-PCGrad-GradVac-AMP-GradAccum