
Reproducing CIFAR10 result from paper: Improved Techniques For Consistency Training

Open noamelata opened this issue 2 years ago • 37 comments

Hi! Has anyone managed to reproduce the results from "Improved Techniques For Consistency Training" on the CIFAR10 dataset?

Thank you for the great repository!

noamelata avatar Nov 18 '23 17:11 noamelata

Hello!

I am not sure about this. So far I haven't carried out large experiments with this.

Thank you for the kind sentiments.

Kinyugo avatar Nov 18 '23 19:11 Kinyugo

@Kinyugo @noamelata Hi everyone. I have run the experiment on a butterflies dataset (len=1000), where the result after 10000 training steps (max_steps) is quite good: FID5 = 0.2028 with five-step sampling and FID1 = 0.2165 with one-step sampling. However, on CIFAR10 I did not get results as good as I expected. I tried:

(1) N(k) = 11, max_steps = 400000, optimizer = RAdam (better than Adam), batch size = 32. Result: FID10 (10-step sampling) = 74.3594, FID5 = 61.3095, FID2 = 127.9843, FID1 = 285.8623.
(2) N(k) = 1280, max_steps = 400000, optimizer = RAdam, batch size = 32. Result: FID10 = 64.5555, FID5 = 56.4666, FID2 = 68.3255, FID1 = 243.7875.
(3) N(k) = 1280, max_steps = 400000, optimizer = RAdam, batch size = 64. Result: FID10 = 30.2732, FID5 = 35.0473, FID2 = 62.3042, FID1 = 238.0072.

With batch size = 128 the results were also poor.
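(For clarity, FIDk above means FID measured with k-step multistep consistency sampling. A rough sketch of that procedure, following Algorithm 1 of the "Consistency Models" paper, where model(x, sigma) is a placeholder for the trained consistency function rather than this repo's exact API:)

```python
import torch

@torch.no_grad()
def multistep_consistency_sampling(model, sigmas, shape, sigma_min=0.002):
    # sigmas: decreasing list of noise levels, e.g. [80.0] for one-step sampling,
    # with more entries appended for multi-step sampling
    x = torch.randn(shape) * sigmas[0]
    x = model(x, sigmas[0])  # a single network evaluation already yields a sample
    for sigma in sigmas[1:]:  # each extra step re-noises and denoises again
        z = torch.randn_like(x)
        x = x + (sigma**2 - sigma_min**2) ** 0.5 * z
        x = model(x, sigma)
    return x.clamp(-1.0, 1.0)
```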

thwgithub avatar Dec 25 '23 13:12 thwgithub

I suspect there is something wrong with the structure of the U-Net, but I am not sure about it.

thwgithub avatar Dec 25 '23 13:12 thwgithub

Hello. Kindly try the experiments with the U-Net from the paper; the one here was just a generic U-Net. If you manage to get it working, kindly share your findings.

Kinyugo avatar Dec 25 '23 13:12 Kinyugo

@Kinyugo Hello, thank you for your instant reply. I found an error in your code, in the timesteps_schedule function:

def timesteps_schedule(...):
    num_timesteps = final_timesteps**2 - initial_timesteps**2
    num_timesteps = current_training_step * num_timesteps / total_training_steps
    num_timesteps = math.ceil(math.sqrt(num_timesteps + initial_timesteps**2) - 1)
    return num_timesteps + 1

The final_timesteps in the first line should be final_timesteps + 1, cf. page 26 in the Appendices of the paper "Consistency Models".
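For anyone reading along, a minimal sketch of the schedule with that fix applied (same formula as Appendix C of "Consistency Models"; the signature here is illustrative, not a verified drop-in patch for this repo):

```python
import math

def timesteps_schedule(current_training_step, total_training_steps,
                       initial_timesteps, final_timesteps):
    # N(k) = ceil( sqrt( k/K * ((s1 + 1)^2 - s0^2) + s0^2 ) - 1 ) + 1
    num_timesteps = (final_timesteps + 1) ** 2 - initial_timesteps**2
    num_timesteps = current_training_step * num_timesteps / total_training_steps
    num_timesteps = math.ceil(math.sqrt(num_timesteps + initial_timesteps**2) - 1)
    return num_timesteps + 1
```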

thwgithub avatar Dec 25 '23 14:12 thwgithub

Thanks for the nice find. I'll correct it ASAP.

Kinyugo avatar Dec 25 '23 15:12 Kinyugo

@Kinyugo However, this trivial error should not obscure the concise elegance of your code; this is excellent work. As far as I know, yours is the first implementation of "improved consistency models" on GitHub.

thwgithub avatar Dec 26 '23 07:12 thwgithub

Thank you for your kind sentiments.

Regarding the model, I agree with you that it might be sub-optimal. A recent paper "Analyzing and Improving the Training Dynamics of Diffusion Models" could be of assistance, though it's based on the EDM paper. I do plan to experiment with the architecture there but I am currently held up. You could also check issue #7 that proposes rescaling of sigmas. I think with some of those changes you might get better results.

Thanks for taking the time to share your findings with me.

Kinyugo avatar Dec 26 '23 08:12 Kinyugo

@Kinyugo Thanks for sharing your findings with me. I had not paid attention to the paper "Analyzing and Improving the Training Dynamics of Diffusion Models". Regarding issue #7, it is worth trying. But first I want to replace your U-Net structure with NCSN++ and then check the result on CIFAR10. To facilitate communication, we could email ([email protected]) each other.

thwgithub avatar Dec 26 '23 08:12 thwgithub

Awesome. Ping me in case of any questions.

If you replicate the ncsn++ network and get good results consider contributing to the repo.

Kinyugo avatar Dec 26 '23 10:12 Kinyugo

@thwgithub Hello! I was wondering if you have had the opportunity to reproduce the FID on CIFAR10 using the U-Net described in the "Consistency Models" paper. If so, may I kindly inquire about the results you obtained? Thanks in advance!

aiihn avatar Dec 29 '23 14:12 aiihn

@aiihn sure.

thwgithub avatar Dec 29 '23 15:12 thwgithub

@thwgithub Thanks! What FID results have you obtained on cifar10?

aiihn avatar Dec 30 '23 09:12 aiihn

@Kinyugo Hello, I have successfully incorporated the NCSN++ network (https://github.com/openai/consistency_models_cifar10/blob/main/jcm/models/ncsnpp.py) into your consistency model. Unfortunately, I still did not achieve good results. I have also gone over your code repeatedly and am fairly sure there is no problem there. Now I am quite confused. Can you help me check my code?

thwgithub avatar Jan 17 '24 08:01 thwgithub

@thwgithub could you provide your code for the ncsn++ as well as the training code and hyperparameters

Kinyugo avatar Jan 17 '24 08:01 Kinyugo

@Kinyugo I have sent it to your email. Did you not receive it?

thwgithub avatar Jan 17 '24 08:01 thwgithub

@Kinyugo your email

thwgithub avatar Jan 17 '24 08:01 thwgithub

@thwgithub No. Consider creating a repo and inviting me via GitHub

Kinyugo avatar Jan 17 '24 10:01 Kinyugo

@Kinyugo I trained on only one GPU (a 4090).

thwgithub avatar Jan 17 '24 13:01 thwgithub

@Kinyugo Thanks for your advice. Please see more details here: https://github.com/thwgithub/ICT_NCSNpp

thwgithub avatar Jan 17 '24 14:01 thwgithub

@thwgithub Any breakthrough? I have been unable to check out your repo due to time constraints on my end.

Kinyugo avatar Jan 25 '24 10:01 Kinyugo

@Kinyugo Hey, I found a tiny error in the improved consistency training code given in your repo: the correct mean of the timestep distribution is -1.1, but you wrote 1.1, as shown below:

nobodyaaa avatar Mar 07 '24 12:03 nobodyaaa

model = ...  # could be our usual unet or any other architecture
loss_fn = ...  # can be anything; pseudo-huber, l1, mse, lpips e.t.c. or a combination of multiple losses
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4, betas=(0.9, 0.995))  # setup your optimizer

Initialize the training module using

improved_consistency_training = ImprovedConsistencyTraining(
    sigma_min = 0.002,       # minimum std of noise
    sigma_max = 80.0,        # maximum std of noise
    rho = 7.0,               # karras-schedule hyper-parameter
    sigma_data = 0.5,        # std of the data
    initial_timesteps = 10,  # number of discrete timesteps during training start
    final_timesteps = 1280,  # number of discrete timesteps during training end
    lognormal_mean = 1.1,    # mean of the lognormal timestep distribution  <<<<----- here
    lognormal_std = 2.0,     # std of the lognormal timestep distribution
)

for current_training_step in range(total_training_steps):
    # Zero out Grads
    optimizer.zero_grad()

    # Forward Pass
    batch = get_batch()
    output = improved_consistency_training(
        student_model,
        batch,
        current_training_step,
        total_training_steps,
        my_kwarg=my_kwarg,  # passed to the model as kwargs useful for conditioning
    )

    # Loss Computation
    loss = (pseudo_huber_loss(output.predicted, output.target) * output.loss_weights).mean()

    # Backward Pass & Weights Update
    loss.backward()
    optimizer.step()

nobodyaaa avatar Mar 07 '24 12:03 nobodyaaa

@Kinyugo This will make the early results inaccurate and then make the entire training worse.
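To illustrate why the sign matters: in the iCT paper the discrete noise levels are drawn from a discretized lognormal with P_mean = -1.1 and P_std = 2.0, which concentrates most of the probability mass on small sigmas. A standalone sketch of that sampling (assuming a precomputed, increasing tensor of sigmas; not necessarily identical to this repo's implementation):

```python
import math
import torch

def sample_timestep_indices(num_samples, sigmas, mean=-1.1, std=2.0):
    # Probability of index i is the lognormal mass between sigma_i and sigma_{i+1};
    # with mean = +1.1 instead of -1.1 the mass shifts toward the largest noise levels.
    log_sigmas = torch.log(sigmas)
    cdf = torch.erf((log_sigmas - mean) / (std * math.sqrt(2.0)))
    pdf = cdf[1:] - cdf[:-1]
    pdf = pdf / pdf.sum()
    return torch.multinomial(pdf, num_samples, replacement=True)
```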

nobodyaaa avatar Mar 07 '24 12:03 nobodyaaa

@nobodyaaa Thanks for catching the error. Fortunately it's a documentation issue and the correct lognormal_mean is used in the ImprovedConsistencyTraining class.

Kinyugo avatar Mar 07 '24 12:03 Kinyugo

@Kinyugo Yeah, I noticed that, thanks for answering. And did you reproduce the iCT results? I have run iCT training several times but didn't get results as good as those shown in the paper; the best one-step generation FID I got on CIFAR10 is about 60.

nobodyaaa avatar Mar 11 '24 06:03 nobodyaaa

Did you run with the same configuration as the paper?

Kinyugo avatar Mar 11 '24 07:03 Kinyugo

@Kinyugo Everything except the neural network: I just used a U-Net copied from somewhere instead of NCSN++. I have read the CM code released by OpenAI, and I used c_in and the rescaled sigma to form the network input just like they did. For the other configuration, like the Karras schedule and so on, I just used your code.
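For anyone comparing implementations, this is roughly what that input scaling looks like (EDM/CM-style preconditioning; a sketch with my own naming, not a verbatim copy of the OpenAI code):

```python
import torch

def consistency_function(model, x_noisy, sigma, sigma_data=0.5, sigma_min=0.002):
    # sigma is assumed broadcastable to x_noisy, e.g. shape [B, 1, 1, 1]
    c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
    c_out = sigma_data * (sigma - sigma_min) / (sigma_data**2 + sigma**2) ** 0.5
    c_in = 1.0 / (sigma**2 + sigma_data**2) ** 0.5  # input rescaling
    rescaled_sigma = 0.25 * torch.log(sigma)  # log-sigma conditioning as in EDM (assumption)
    return c_skip * x_noisy + c_out * model(c_in * x_noisy, rescaled_sigma.flatten())
```

With these scalings, f(x, sigma_min) = x, which is the boundary condition the consistency parameterization relies on.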

nobodyaaa avatar Mar 11 '24 07:03 nobodyaaa

How many iterations did you train for? What batch size, etc.?

Kinyugo avatar Mar 11 '24 07:03 Kinyugo