FlowModels

student_loss.backward() in LADD

Open jzhang38 opened this issue 1 year ago • 3 comments

student_loss.backward()
torch.nn.utils.clip_grad_norm_(student_unet.parameters(), 1.0)
student_optimizer.step()
student_scheduler.step()
student_optimizer.zero_grad()

Wouldn't the above code generate gradients on the discriminator as well? Then, in the next training iteration, those gradients on the discriminator would be used in its optimizer.step(). I think we need a discriminator_optimizer.zero_grad() after student_optimizer.zero_grad()?
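
For illustration, a minimal standalone sketch of the concern (stand-in modules and a generic adversarial-style loss, not the repo's actual code), showing where the gradients end up after this backward:

    import torch
    import torch.nn as nn

    # Tiny stand-ins for the real student UNet and discriminator.
    student_unet = nn.Linear(8, 8)
    discriminator = nn.Linear(8, 1)

    x = torch.randn(4, 8)
    # The student loss is computed *through* the discriminator.
    student_loss = -discriminator(student_unet(x)).mean()
    student_loss.backward()

    # Gradients now sit on BOTH modules; unless they are cleared, the next
    # discriminator step would consume this stale generator gradient.
    print(all(p.grad is not None for p in discriminator.parameters()))  # True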

jzhang38 · Nov 23 '24 20:11

    x_1_approx_noised = (1 - reshape_t(renoise_timesteps)) * x_1_approx + reshape_t(renoise_timesteps) * x_0_latent

I believe this line is wrong. The correct version should be:

    x_1_approx_noised = reshape_t(renoise_timesteps) * x_1_approx + (1 - reshape_t(renoise_timesteps)) * x_0_latent
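
As a quick endpoint check (a standalone sketch under the convention $t \, x_1 + (1 - t)\, x_0$, with stand-in tensors for x_1_approx and x_0_latent):

    import torch

    x_1 = torch.randn(2, 4)   # stand-in for x_1_approx
    x_0 = torch.randn(2, 4)   # stand-in for x_0_latent

    def interp(t):
        # The proposed line, written as a function of the scalar timestep t.
        return t * x_1 + (1 - t) * x_0

    assert torch.allclose(interp(torch.tensor(1.0)), x_1)  # t = 1 recovers x_1
    assert torch.allclose(interp(torch.tensor(0.0)), x_0)  # t = 0 recovers x_0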

jzhang38 · Nov 23 '24 22:11

    x_1_approx_noised = (1 - reshape_t(renoise_timesteps)) * x_1_approx + reshape_t(renoise_timesteps) * x_0_latent

I believe this line is wrong. The correct version should be:

    x_1_approx_noised = reshape_t(renoise_timesteps) * x_1_approx + (1 - reshape_t(renoise_timesteps)) * x_0_latent

Well, this is correct; however, the noising here was deliberately done in "reverse", to follow the LADD SD3 discriminator noising process.

I sampled timesteps from LogitNormal(1, 1); however, if I had followed $t \, x_1 + (1 - t)\, x_0$ I should have flipped the logit-normal distribution and used LogitNormal(-1, 1), but I wanted to do it as in the paper. Sorry for the confusion.
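
A small numerical check of that equivalence (a sketch, not the repo's sampler): if $t = \sigma(z)$ with $z \sim \mathcal{N}(1, 1)$, then $1 - t = \sigma(-z)$ with $-z \sim \mathcal{N}(-1, 1)$, so the two formulations put the same distribution of weights on x_1_approx.

    import torch

    z = torch.randn(100_000) + 1.0                   # z ~ N(1, 1)
    t = torch.sigmoid(z)                             # t ~ LogitNormal(1, 1)
    s = torch.sigmoid(torch.randn(100_000) - 1.0)    # s ~ LogitNormal(-1, 1)

    # Weight multiplying x_1_approx in each formulation:
    w_repo = 1 - t   # from (1 - t) * x_1_approx + t * x_0_latent
    w_flip = s       # from s * x_1_approx + (1 - s) * x_0_latent
    print(w_repo.mean().item(), w_flip.mean().item())  # the two agree closely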

leffff · Nov 26 '24 19:11

student_loss.backward()
torch.nn.utils.clip_grad_norm_(student_unet.parameters(), 1.0)
student_optimizer.step()
student_scheduler.step()
student_optimizer.zero_grad()

Wouldn't the above code generate gradients on the discriminator as well? Then, in the next training iteration, those gradients on the discriminator would be used in its optimizer.step(). I think we need a discriminator_optimizer.zero_grad() after student_optimizer.zero_grad()?

This may be true! I'm not sure yet and don't have time to check right now. At such a small scale it doesn't affect memory consumption.
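
For reference, a standalone sketch of the two obvious fixes (stand-in modules and optimizers with hypothetical names, not the repo's actual objects): either clear the leaked gradients, or stop them from being created in the first place.

    import torch
    import torch.nn as nn

    student_unet = nn.Linear(8, 8)        # stand-in modules
    discriminator = nn.Linear(8, 1)
    student_optimizer = torch.optim.AdamW(student_unet.parameters(), lr=1e-4)
    discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=1e-4)
    x = torch.randn(4, 8)

    # Option A: keep the backward as-is, but clear the discriminator's leaked
    # gradients together with the student's.
    student_loss = -discriminator(student_unet(x)).mean()
    student_loss.backward()
    torch.nn.utils.clip_grad_norm_(student_unet.parameters(), 1.0)
    student_optimizer.step()
    student_optimizer.zero_grad()
    discriminator_optimizer.zero_grad()   # drops gradients left by the student backward

    # Option B: freeze the discriminator's parameters during the student backward,
    # so no gradients are created on it at all.
    discriminator.requires_grad_(False)
    student_loss = -discriminator(student_unet(x)).mean()
    student_loss.backward()
    discriminator.requires_grad_(True)
    student_optimizer.step()
    student_optimizer.zero_grad()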

leffff · Nov 26 '24 19:11