lightning-hydra-template
bug: log metrics in DDP mode
Hi ashleve,
I found that this line logs `acc` without setting `sync_dist=True`, so it may log unexpected values in DDP mode. As I understand it, the correct way is:
```python
self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
```
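To make "unexpected values" concrete, here is a minimal toy sketch in plain Python (no Lightning, no real DDP). It assumes two hypothetical ranks with made-up predictions, and assumes that an unsynchronized value ends up being reported from rank 0 only, which is how Lightning loggers typically behave by default.

```python
# Toy illustration (no real DDP): two hypothetical "ranks" each see a
# different shard of the data. All values here are made up.
shards = {
    0: ([1, 0, 1, 1], [1, 0, 0, 1]),              # (preds, targets) on rank 0
    1: ([0, 0, 1, 0, 1, 1], [1, 0, 1, 1, 1, 0]),  # (preds, targets) on rank 1
}

def local_accuracy(preds, targets):
    """Accuracy computed only from the data this process has seen."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)

per_rank = {rank: local_accuracy(*shard) for rank, shard in shards.items()}

# Without sync_dist=True, the value that reaches the logger is (assumed to be)
# rank 0's local accuracy only.
logged_without_sync = per_rank[0]

# The accuracy over the whole dataset, which is what we actually want to log.
all_preds = [p for preds, _ in shards.values() for p in preds]
all_targets = [t for _, targets in shards.values() for t in targets]
global_accuracy = local_accuracy(all_preds, all_targets)

print(per_rank)             # {0: 0.75, 1: 0.5}
print(logged_without_sync)  # 0.75
print(global_accuracy)      # 0.6
```

The logged 0.75 reflects only part of the data, while the true accuracy is 0.6.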
The torchmetrics documentation gives two ways to log metrics:
```python
import torchmetrics
from pytorch_lightning import LightningModule

class MyModule(LightningModule):
    def __init__(self):
        ...
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # update the metric state, then log the metric object itself
        self.train_acc(preds, y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
```
and
```python
import torchmetrics
from pytorch_lightning import LightningModule

class MyModule(LightningModule):
    def __init__(self):
        ...
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # forward() returns the accuracy of the current batch only
        batch_value = self.train_acc(preds, y)
        self.log('train_acc_step', batch_value)

    def training_epoch_end(self, outputs):
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        # accumulate the metric state without computing a per-batch value
        self.valid_acc.update(logits, y)

    def validation_epoch_end(self, outputs):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()
```
In method 1, the `self.log` call in `training_step` takes care of synchronization automatically in DDP mode. In method 2, `self.log` only receives the values returned by torchmetrics, not the torchmetrics object, so the `self.log` call in `validation_epoch_end` needs to set `sync_dist=True` to synchronize the metric across processes.
The first method works in both DDP and non-DDP mode, while the second one only works as-is in non-DDP mode. It's better to use the first method and log the torchmetrics object directly, so the logging code does not need to change when switching to DDP mode.
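One more reason to prefer logging the metric object, sketched below under simplifying assumptions: `sync_dist=True` reduces the already-computed per-process floats (by averaging, with Lightning's default `reduce_fx="mean"`), whereas a torchmetrics object synchronizes its internal counts before computing. When ranks see different numbers of samples, the two results can differ. The counts below are hypothetical and no real DDP is involved.

```python
# Toy illustration (no real DDP): compare two ways of combining per-rank results.
# Each hypothetical rank keeps (correct, total) counts for its own shard.
rank_counts = [
    (3, 4),  # rank 0: 3 of 4 samples correct
    (3, 6),  # rank 1: 3 of 6 samples correct
]

# (a) Each rank computes a plain float, then the floats are averaged,
#     analogous to logging a value with sync_dist=True and a mean reduction.
per_rank_values = [correct / total for correct, total in rank_counts]
mean_of_values = sum(per_rank_values) / len(per_rank_values)

# (b) Sum the metric *states* across ranks first, then compute once,
#     analogous to how a torchmetrics object syncs its state before compute().
total_correct = sum(correct for correct, _ in rank_counts)
total_seen = sum(total for _, total in rank_counts)
state_synced = total_correct / total_seen

print(mean_of_values)  # 0.625
print(state_synced)    # 0.6
```

With equal shard sizes the two agree; with unequal ones, only (b) matches the accuracy over the full dataset.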
I asked about this in the pytorch_lightning channel, and Rohit confirmed that logging the value instead of the torchmetrics object requires setting `sync_dist=True` in DDP mode.
@zhiyuanpeng Thank you for pointing this out! Here's a proposed fix for the issue: https://github.com/ashleve/lightning-hydra-template/pull/426