The density_for_timestep_sampling and loss_weighting for SD3 Training!!!

Open DidiD1 opened this issue 1 year ago • 17 comments

Thanks to Rafie Walker's code, we can try to train SD3 models with flow matching! But some places don't seem to match what's in the paper. Rafie Walker's code is below:

import math

import torch

def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$):
        # sample in logit space, then squash to (0, 1) with a sigmoid.
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        # "Mode sampling with heavy tails" from the SD3 paper.
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        # Plain uniform sampling over (0, 1).
        u = torch.rand(size=(batch_size,), device="cpu")
    return u

def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        # Cosine-map weighting from the SD3 paper.
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        # No explicit weighting: every timestep contributes equally.
        weighting = torch.ones_like(sigmas)
    return weighting
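
For context, here is a minimal sketch of how these two helpers fit together in a flow-matching training step (my own illustration, loosely following the diffusers SD3 training scripts; the latents, noise, sigma table, and dummy model prediction are placeholders, not real training code):

```python
import torch

# Placeholders for illustration only.
bsz = 4
num_train_timesteps = 1000
latents = torch.randn(bsz, 4, 32, 32)
noise = torch.randn_like(latents)
all_sigmas = torch.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)

# 1) Sample timestep positions u in (0, 1) from the chosen density.
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=bsz, logit_mean=0.0, logit_std=1.0
)
indices = (u * num_train_timesteps).long().clamp(max=num_train_timesteps - 1)
sigmas = all_sigmas[indices].view(-1, 1, 1, 1)

# 2) Flow-matching forward process: interpolate between data and noise.
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
target = noise - latents  # velocity target

# 3) In real training, model_pred = transformer(noisy_latents, timesteps, ...).
model_pred = torch.zeros_like(target)  # dummy stand-in

# 4) Weight the per-element MSE; for "logit_normal" this returns all ones.
weighting = compute_loss_weighting_for_sd3(weighting_scheme="logit_normal", sigmas=sigmas)
loss = (weighting * (model_pred - target) ** 2).mean()
```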

My questions are below:

  1. When weighting_scheme == "mode", the code only computes f_mode. If you need to compute 'u', shouldn't there be some additional operation?
  2. Cos-map seems to compute the weight of the timesteps, not the weight of the loss?
  3. When we use logit_normal, it is based on the RF setting, so the loss weight should be t/(1-t), but the code doesn't compute that weight and instead uses torch.ones_like(sigmas)? (A sketch of such a weight follows this list.)
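
For question 3, an explicit RF-style weight would look something like this (my own sketch of the weight the question suggests, with sigmas standing in for t; it is not what the diffusers code currently does):

```python
import torch

def rf_loss_weighting(sigmas: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Hypothetical t/(1-t) weight from the rectified-flow view, with sigmas
    # standing in for t; clamped so the weight stays finite at the endpoints.
    t = sigmas.clamp(eps, 1.0 - eps)
    return t / (1.0 - t)
```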

So I think some modifications are needed to correctly compute the SD3 loss! Thanks for discussing this together!

DidiD1 avatar Aug 02 '24 08:08 DidiD1

honestly none of the weighting tricks really seem relevant to finetuning SD3. not using the timestep weighting has better results.

bghira avatar Aug 02 '24 14:08 bghira

honestly none of the weighting tricks really seem relevant to finetuning SD3. not using the timestep weighting has better results.

Could you give some more details? Thanks a lot.

xiao2mo avatar Aug 07 '24 02:08 xiao2mo

yes, if you look at the timestep selection distribution using the SD3 style training, it effectively does not ever train the 900-1000 or 0-100 range of timesteps. they are just ignored:

[image: timestep selection distribution chart]

ignoring the gaps in the chart here (wandb was having issues), the timestep selection at the end is where I switched to uniform sampling and the model started learning composition and details properly
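
This is easy to check numerically. A quick sketch of my own, reusing the logit-normal sampler from the snippet above, shows how little mass the lognorm(0, 1) scheme puts on the extreme timesteps:

```python
import torch

# Draw u ~ sigmoid(N(0, 1)), the "logit_normal" scheme, and map to timesteps.
u = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=(1_000_000,)))
t = (u * 1000).long()

low = (t < 100).float().mean().item()    # fraction landing in timesteps 0-99
high = (t >= 900).float().mean().item()  # fraction landing in timesteps 900-999
print(f"0-100: {low:.2%}, 900-1000: {high:.2%}")
# Each tail gets roughly 1-2% of the samples, versus 10% each under
# uniform sampling, so the extremes are trained very rarely.
```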

bghira avatar Aug 07 '24 03:08 bghira

This phenomenon was mentioned in the SD3 paper, which may be why they proposed the 'mode sampling with heavy tails' time-sampling method. However, it's strange that in their experimental results 'log-norm' performs much better than 'mode' and uniform sampling. So I guess each sampling method has its own advantages, and experiments are needed to validate which one is suitable for a given task.
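
For what it's worth, the heavy-tail effect is easy to see numerically. A small sketch of my own comparing the tail mass of the three schemes (mode_scale = 1.29 is a value used in the paper's ablations and, I believe, the diffusers default):

```python
import math
import torch

n = 1_000_000
u = torch.rand(n)

# "Mode sampling with heavy tails" from the SD3 paper.
s = 1.29
t_mode = 1 - u - s * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)

# logit-normal(0, 1) and plain uniform for comparison.
t_lognorm = torch.sigmoid(torch.randn(n))

for name, t in [("uniform", u), ("lognorm", t_lognorm), ("mode", t_mode)]:
    tail = ((t < 0.1) | (t > 0.9)).float().mean().item()
    print(f"{name}: {tail:.2%} of samples in the outer 20% of timesteps")
# mode(1.29) keeps noticeably more mass near the endpoints than lognorm,
# while still concentrating samples in the middle of the schedule.
```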

DidiD1 avatar Aug 07 '24 03:08 DidiD1

[image]

bghira avatar Aug 07 '24 03:08 bghira

it just needs an absolutely enormous batch size for these to make sense.

edit: also worth noting these parameters are likely dependent on model size, the same way LR scales with model size when not using microsoft/mup

bghira avatar Aug 07 '24 03:08 bghira

Thanks a lot. And for my question 3: "when we use logit_normal, it is based on the RF setting, so the loss weight should be t/(1-t), but the code doesn't compute that weight and instead uses torch.ones_like(sigmas)?" Do I need to modify the loss weight?

DidiD1 avatar Aug 07 '24 03:08 DidiD1

yes, if you look at the timestep selection distribution using the SD3 style training, it effectively does not ever train the 900-1000 or 0-100 range of timesteps. they are just ignored:

[image: timestep selection distribution chart]

ignoring the gaps in the chart here (wandb was having issues), the timestep selection at the end is where I switched to uniform sampling and the model started learning composition and details properly

Thanks a lot

xiao2mo avatar Aug 22 '24 02:08 xiao2mo

yes, if you look at the timestep selection distribution using the SD3 style training, it effectively does not ever train the 900-1000 or 0-100 range of timesteps. they are just ignored:

[image: timestep selection distribution chart]

ignoring the gaps in the chart here (wandb was having issues), the timestep selection at the end is where I switched to uniform sampling and the model started learning composition and details properly

@bghira Hi bghira~ I'd like to know: when you tried "SD3 style training (lognorm sampling)" versus "uniform sampling", what was the difference in the training loss? When you switched to uniform sampling, did it help lower the loss curve? In my uniform training there are still some artifacts in the generated images, so I wonder which part of the noise sampling is important for improving this problem. I want to hear your insights, thanks~

ivylilili avatar Aug 23 '24 08:08 ivylilili

currently we're using sigmoid sampling for timesteps which seems fine but no one has really ablated whether it leaves fine details out

bghira avatar Aug 23 '24 12:08 bghira

currently we're using sigmoid sampling for timesteps which seems fine but no one has really ablated whether it leaves fine details out

Actually, sigmoid and lognorm are mathematically equivalent. But I'm curious why existing open-source training implementations don't use the timeshift during training, while the SD3 paper does.

[image]
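
On the equivalence point, a quick numerical check (my own sketch): a logit-normal variable is by definition a Gaussian pushed through a sigmoid, so the two samplers produce the same distribution.

```python
import torch

torch.manual_seed(0)
z = torch.normal(mean=0.0, std=1.0, size=(1_000_000,))

# "Sigmoid sampling": squash a standard Gaussian through a sigmoid.
t_sigmoid = torch.sigmoid(z)

# "Logit-normal sampling": t is logit-normal iff logit(t) is normal, which
# is exactly the same transform written out by hand.
t_lognorm = 1.0 / (1.0 + torch.exp(-z))

print(torch.allclose(t_sigmoid, t_lognorm))  # True
```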

culeao avatar Aug 27 '24 03:08 culeao

currently we're using sigmoid sampling for timesteps which seems fine but no one has really ablated whether it leaves fine details out

Actually, sigmoid and lognorm are mathematically equivalent. But I'm curious why existing open-source training implementations don't use the timeshift during training, while the SD3 paper does.

[image]

In fact, the diffusers version of SD3 already uses timeshifting. You can see it in the init config of FlowMatchEulerDiscreteScheduler: { "_class_name": "FlowMatchEulerDiscreteScheduler", "_diffusers_version": "0.29.0.dev0", "num_train_timesteps": 1000, "shift": 3.0 }

The shifted schedule:

sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
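
A small sketch of my own showing what that shift does to the schedule:

```python
import torch

shift = 3.0  # the value from the SD3 scheduler config above
sigmas = torch.linspace(1.0, 1.0 / 1000, 1000)  # un-shifted schedule, ~1 -> 0

shifted = shift * sigmas / (1 + (shift - 1) * sigmas)

# The shift pushes the whole trajectory toward the noisy end: the un-shifted
# midpoint sigma = 0.5 maps to 3 * 0.5 / (1 + 2 * 0.5) = 0.75, so more steps
# are spent at high noise levels.
print(shifted[499].item())  # ~0.75 for the sigma closest to 0.5
```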

DidiD1 avatar Aug 27 '24 09:08 DidiD1

Thanks a lot. And for my question 3: "when we use logit_normal, it is based on the RF setting, so the loss weight should be t/(1-t), but the code doesn't compute that weight and instead uses torch.ones_like(sigmas)?" Do I need to modify the loss weight?

@DidiD1 Hi! Thank you for your valuable discussion about the timestep schedule. It really helps me. BTW, did you find any answer to this question?

jjihwan avatar Feb 03 '25 07:02 jjihwan

Thanks a lot. And for my question 3: "when we use logit_normal, it is based on the RF setting, so the loss weight should be t/(1-t), but the code doesn't compute that weight and instead uses torch.ones_like(sigmas)?" Do I need to modify the loss weight?

Thank you for your insightful discussion. Did you find any answer to the question? If training is fine with torch.ones_like(sigmas), does this mean that the loss weighting doesn't make much difference to the result?

wgsxm avatar Apr 22 '25 19:04 wgsxm

Re-sampling timesteps has a similar effect to changing the loss weight.
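
That equivalence is just importance sampling. A small sketch of my own: weighting uniformly drawn timesteps by the logit-normal density gives the same expected loss as sampling timesteps from the logit-normal directly.

```python
import math
import torch

torch.manual_seed(0)
n = 1_000_000

def f(t):
    # Arbitrary stand-in for the per-timestep loss.
    return t ** 2

def lognorm_pdf(t):
    # Density of sigmoid(N(0, 1)) on (0, 1), i.e. the logit-normal pdf.
    logit = torch.log(t / (1 - t))
    return torch.exp(-0.5 * logit ** 2) / (math.sqrt(2 * math.pi) * t * (1 - t))

# (a) re-sample: draw t from the logit-normal, no loss weight.
t_a = torch.sigmoid(torch.randn(n))
est_a = f(t_a).mean()

# (b) re-weight: draw t uniformly, weight each loss term by the density.
t_b = torch.rand(n).clamp(1e-6, 1 - 1e-6)
est_b = (lognorm_pdf(t_b) * f(t_b)).mean()

print(est_a.item(), est_b.item())  # agree up to Monte Carlo noise
```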

CyberPegasus avatar Jun 10 '25 06:06 CyberPegasus