
The density_for_timestep_sampling and loss_weighting for SD3 Training!!!

Open DidiD1 opened this issue 1 year ago • 17 comments

Thanks to Rafie Walker's code, we can try training SD3 models with flow matching! However, some parts don't seem to match what's in the paper. Rafie Walker's code is below:

import math

import torch


def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u

def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting

My questions are below:

  1. When weighting_scheme == "mode", the code only computes f_mode. If we need to compute 'u', shouldn't there be some additional operation?
  2. CosMap seems to define the weight (density) of the timesteps, not the weight of the loss?
  3. When we use logit_normal, it is based on the RF setting, so the weight of the loss should be t/(1-t). But the code doesn't compute this weight and returns torch.ones_like(sigmas) instead?
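
To make questions 1 and 2 concrete, here is a small stdlib-only sketch (separate from the training code above). It assumes, as I read the SD3 paper, that the mode scheme draws timesteps as t = f_mode(u; s) with u ~ U(0, 1), and that the CosMap expression 2 / (pi * (1 - 2t + 2t^2)) is a density over t with sampling map t = 1 - 1 / (tan(pi * u / 2) + 1). The helper names and the value s = 1.29 are my own choices for illustration, not anything from the repository.

```python
import math
import random


def f_mode(u, s):
    # Same expression as the "mode" branch above, written for a scalar u.
    return 1 - u - s * (math.cos(math.pi * u / 2) ** 2 - 1 + u)


def cosmap_density(t):
    # Same expression as the "cosmap" branch above, read as a density over t.
    return 2 / (math.pi * (1 - 2 * t + 2 * t ** 2))


def cosmap_sample(u):
    # Assumed sampling map for CosMap (my reading of the paper).
    return 1 - 1 / (math.tan(math.pi * u / 2) + 1)


def integrate_unit_interval(fn, n=100_000):
    # Midpoint rule on [0, 1].
    h = 1.0 / n
    return sum(fn((i + 0.5) * h) for i in range(n)) * h


random.seed(0)
mode_samples = [f_mode(random.random(), 1.29) for _ in range(10_000)]
cosmap_samples = [cosmap_sample(random.random()) for _ in range(10_000)]
density_mass = integrate_unit_interval(cosmap_density)

print("min/max of mode samples:", min(mode_samples), max(mode_samples))
print("mean of CosMap samples:", sum(cosmap_samples) / len(cosmap_samples))
print("integral of CosMap expression over [0, 1]:", density_mass)
```

If these assumptions hold, the CosMap expression integrates to 1 over [0, 1], which is what you would expect from a sampling density rather than a loss weight, and applying f_mode to uniform u already gives values in [0, 1] that could be used directly as timesteps.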

So I think some modifications are needed to compute the SD3 loss correctly! Thanks for discussing this together!
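
For point 3, a scalar sketch of how the t/(1-t) weight might relate to a unit weight on the velocity loss. The assumptions here are mine: the paper's weighted noise-prediction loss has the form -1/2 * w_t * lambda'_t * (eps_pred - eps)^2 with lambda_t = 2 * log((1 - t) / t), the RF path is z_t = (1 - t) * x0 + t * eps, and a velocity prediction converts to a noise prediction via eps_pred = (1 - t) * v_pred + z_t. The function name rf_losses is hypothetical.

```python
import math
import random


def rf_losses(x0, eps, t, v_pred):
    # Rectified-flow forward path and its velocity target.
    z_t = (1 - t) * x0 + t * eps
    v_target = eps - x0

    # Convert the velocity prediction into a noise prediction
    # (exact when v_pred == v_target, since then eps_pred == eps).
    eps_pred = (1 - t) * v_pred + z_t

    # Derivative of lambda_t = 2 * log((1 - t) / t) with respect to t.
    lam_prime = -2 / (t * (1 - t))

    # Noise-prediction loss with the RF weight w_t = t / (1 - t).
    w_t = t / (1 - t)
    weighted_eps_loss = -0.5 * w_t * lam_prime * (eps_pred - eps) ** 2

    # Plain squared error on the velocity, i.e. weight 1.
    unweighted_v_loss = (v_pred - v_target) ** 2
    return weighted_eps_loss, unweighted_v_loss


random.seed(0)
pairs = []
for _ in range(5):
    x0, eps, v_pred = (random.gauss(0, 1) for _ in range(3))
    t = random.uniform(0.05, 0.95)
    pairs.append(rf_losses(x0, eps, t, v_pred))

for weighted_eps_loss, unweighted_v_loss in pairs:
    print(weighted_eps_loss, unweighted_v_loss)
```

Under these assumptions the two numbers in each pair agree, which would mean that weight 1 on the velocity error already corresponds to w_t = t/(1-t) on the noise error. Whether the training script's loss is actually written in the velocity form is something to confirm against the script itself.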

DidiD1, Aug 02 '24 08:08