
bug: log metrics in DDP mode

Open · zhiyuanpeng opened this issue on Jul 27, 2022

Hi ashleve,

I noticed that this line logs `acc` without setting `sync_dist=True`, which may log unexpected values in DDP mode. The correct way, in my understanding, is:

self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)

The torchmetrics documentation gives two methods for logging metrics:

class MyModule(LightningModule):

    def __init__(self):
        ...
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        self.train_acc(preds, y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

and

class MyModule(LightningModule):

    def __init__(self):
        ...
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        batch_value = self.train_acc(preds, y)
        self.log('train_acc_step', batch_value)

    def training_epoch_end(self, outputs):
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc.update(logits, y)

    def validation_epoch_end(self, outputs):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()

The `self.log` call in method 1's `training_step` is passed the torchmetrics object, so synchronization is taken care of automatically in DDP mode. The `self.log` calls in method 2 only log the values returned by torchmetrics, not the torchmetrics object, so `validation_epoch_end` needs to set `sync_dist=True` to handle metric synchronization.
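Concretely, under that reasoning, method 2's epoch hook would become the following (my edit, not part of the torchmetrics docs):

    def validation_epoch_end(self, outputs):
        # compute() returns a plain tensor; since the metric object itself
        # is not logged, cross-process reduction is requested explicitly
        self.log('valid_acc_epoch', self.valid_acc.compute(), sync_dist=True)
        self.valid_acc.reset()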

The first method works in both DDP and non-DDP mode, but the second only gives correct values in non-DDP mode. It's better to use the first method and log the torchmetrics object, so there is no need to change the logging code when switching to DDP mode.

I asked this question in the pytorch_lightning channel, and Rohit clarified that logging the value instead of the torchmetrics object requires setting `sync_dist=True` in DDP mode.

— zhiyuanpeng, Jul 27, 2022

@zhiyuanpeng Thank you for pointing this out! Here's a proposed fix for the issue: https://github.com/ashleve/lightning-hydra-template/pull/426

— ashleve, Aug 24, 2022