
🐛[BUG]: CorrDiff loss is scaled by hyper-parameter

Open · chychen opened this issue 7 months ago · 0 comments

Version

latest

On which installation method(s) does this occur?

Source

Describe the issue

The CorrDiff training loss is scaled by hyper-parameters (`batch_gpu_total` and `batch_size_gpu`), so the logged loss values of runs with different batch settings are not on the same scale. This makes hyper-parameter search impractical, because runs cannot be compared with one another.

example (writing L for the value logged with the baseline settings; a toy sketch reproducing this follows):

  • if `batch_gpu_total` = 1, `loss_accum` = L; with `batch_gpu_total` = 2, `loss_accum` = L/2
  • if `batch_size_gpu` = 1, `loss_accum` = L; with `batch_size_gpu` = 2, `loss_accum` = 2*L
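
The following standalone sketch (hypothetical: no model or DDP, a constant per-sample loss of 1.0, and `loss_scaling` = 1 are all illustrative assumptions) mimics the accumulation loop shown under "Current implementation" below and reproduces both scalings:

    def loss_accum_current(batch_gpu_total, batch_size_gpu):
        # toy stand-in for the real training loop; per-sample loss fixed at 1.0
        num_accumulation_rounds = batch_gpu_total // batch_size_gpu
        loss_accum = 0.0
        for _ in range(num_accumulation_rounds):
            round_sum = batch_size_gpu * 1.0    # loss.sum() over one micro-batch
            loss = round_sum / batch_gpu_total  # loss_scaling = 1
            loss_accum += loss / num_accumulation_rounds
        return loss_accum

    print(loss_accum_current(batch_gpu_total=1, batch_size_gpu=1))  # 1.0  -> L
    print(loss_accum_current(batch_gpu_total=2, batch_size_gpu=1))  # 0.5  -> L/2
    print(loss_accum_current(batch_gpu_total=4, batch_size_gpu=1))  # 0.25
    print(loss_accum_current(batch_gpu_total=4, batch_size_gpu=2))  # 0.5  (2x the line above)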

Why not simply normalize by `batch_size_global` (the global batch size across all ranks, i.e. `batch_gpu_total * world_size`), as below?

Current implementation

    for round_idx in range(num_accumulation_rounds):
        with ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
            ...
            # per-round loss sum, normalized by the full per-GPU batch
            loss = loss.sum().mul(loss_scaling / batch_gpu_total)
            # the extra division by num_accumulation_rounds couples the logged
            # value to batch_size_gpu and batch_gpu_total
            loss_accum += loss / num_accumulation_rounds
            loss.backward()

    loss_sum = torch.tensor([loss_accum], device=device)
    if dist.world_size > 1:
        torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM)
    average_loss = loss_sum / dist.world_size
    if dist.rank == 0:
        wb.log({"training loss": average_loss}, step=cur_nimg)
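
Expanding the current loop in closed form (with c = `loss_scaling`, B = `batch_gpu_total`, R = `num_accumulation_rounds` = B / `batch_size_gpu`, S_r the per-sample loss sum of accumulation round r, and L̄ the per-sample mean loss, so that the S_r sum to B·L̄):

$$
\mathrm{loss\_accum} \;=\; \frac{1}{R}\sum_{r=1}^{R}\frac{c\,S_r}{B} \;=\; \frac{c\,B\,\bar{L}}{R\,B} \;=\; c\,\bar{L}\cdot\frac{\mathrm{batch\_size\_gpu}}{\mathrm{batch\_gpu\_total}}
$$

This matches both bullets above: doubling `batch_gpu_total` halves the logged value, while doubling `batch_size_gpu` at fixed `batch_gpu_total` doubles it.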

Proposed modification

    for round_idx in range(num_accumulation_rounds):
        with ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
            ...
            loss = loss.sum().mul(loss_scaling / batch_size_global)  ### Modified: normalize by the global batch size
            loss_accum += loss  ### Modified: no division by num_accumulation_rounds
            loss.backward()

    loss_sum = torch.tensor([loss_accum], device=device)
    if dist.world_size > 1:
        torch.distributed.all_reduce(loss_sum, op=torch.distributed.ReduceOp.SUM)
    average_loss = loss_sum / dist.world_size
    if dist.rank == 0:
        wb.log({"training loss": average_loss}, step=cur_nimg)
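
Under the same toy assumptions as the sketch above (constant per-sample loss of 1.0, `loss_scaling` = 1, single rank, so `batch_size_global` = `batch_gpu_total`), the proposed normalization yields the same accumulated value regardless of the per-GPU batch settings:

    def loss_accum_proposed(batch_gpu_total, batch_size_gpu, world_size=1):
        # toy stand-in for the proposed loop; per-sample loss fixed at 1.0
        batch_size_global = batch_gpu_total * world_size
        num_accumulation_rounds = batch_gpu_total // batch_size_gpu
        loss_accum = 0.0
        for _ in range(num_accumulation_rounds):
            round_sum = batch_size_gpu * 1.0  # loss.sum() over one micro-batch
            loss_accum += round_sum / batch_size_global  # loss_scaling = 1
        return loss_accum

    print(loss_accum_proposed(1, 1))  # 1.0
    print(loss_accum_proposed(2, 1))  # 1.0
    print(loss_accum_proposed(4, 2))  # 1.0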

Minimum reproducible example

See the README.

Relevant log output

Same example as under "Describe the issue" above.

Environment details

No response

chychen · Jul 17 '24 01:07