
Rescale of sigmas

Open wubowen416 opened this issue 2 years ago • 5 comments

Hi, nice repo, really appreciate it.

One thing I noticed: in the implementation of Song's consistency models, the sigmas are rescaled before being passed to the network: `rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)`. You can check it here: https://github.com/openai/consistency_models/blob/e32b69ee436d518377db86fb2127a3972d0d8716/cm/karras_diffusion.py#L346C58-L346C58

Similarly, EDM's implementation also rescales sigma before passing it to the network: `c_noise = sigma.log() / 4`. Link: https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/training/networks.py#L663C9-L663C34
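For reference, here is a minimal sketch of the two rescalings quoted above, written with the standard library instead of PyTorch. The function names `rescale_cm` and `rescale_edm` are hypothetical and do not appear in either repository. It only illustrates that the two formulas differ by a constant factor of 1000 (up to the tiny epsilon).

```python
import math


def rescale_cm(sigma: float) -> float:
    """Rescaling quoted from the consistency_models repo: 1000 * 0.25 * log(sigma + 1e-44)."""
    return 1000.0 * 0.25 * math.log(sigma + 1e-44)


def rescale_edm(sigma: float) -> float:
    """Rescaling quoted from the EDM repo: c_noise = log(sigma) / 4."""
    return math.log(sigma) / 4.0


if __name__ == "__main__":
    for sigma in (0.002, 1.0, 80.0):
        # The two values should differ by a factor of 1000.
        print(sigma, rescale_edm(sigma), rescale_cm(sigma))
```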

But I did not find this rescaling in your implementation.

I am aware that the code for the improved consistency models has not been released yet, so we do not really know whether such an operation is used there. What do you think?

wubowen416 avatar Dec 08 '23 09:12 wubowen416

Hello,

Thanks for the nice findings and for your interest in my work. I think the goal is to rescale the values to a range that works well for the chosen timestep embedding. Consistency models do something quite similar to EDM but then scale by 1000, because they use sinusoidal embeddings for the timestep. EDM uses Fourier embeddings, so its rescaling outputs values close to [-1, 1].

It's a common practice in deep learning to rescale values such that they are within a small range, but in our case we use the raw values in the range [0.02, 80.0]. This is not something I have experimented with and I don't know how it would impact the performance of the model. If you do manage to experiment with it kindly share your findings.

[Images: rescaled_sigmas_cm, rescaled_sigmas_edm]
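The ranges discussed in this comment can be checked numerically. The sketch below assumes the sigma range [0.02, 80.0] quoted above and uses hypothetical helper names. Since the logarithm is monotonic, evaluating each rescaling at the two endpoints gives its output range.

```python
import math

# Sigma range as quoted in the comment above (an assumption of this sketch).
SIGMA_MIN, SIGMA_MAX = 0.02, 80.0


def edm_scale(sigma: float) -> float:
    """EDM-style rescaling: log(sigma) / 4."""
    return math.log(sigma) / 4.0


def cm_scale(sigma: float) -> float:
    """Consistency-models-style rescaling: 1000 * 0.25 * log(sigma + 1e-44)."""
    return 1000.0 * 0.25 * math.log(sigma + 1e-44)


# Both functions are increasing in sigma, so the endpoints bound the range.
edm_range = (edm_scale(SIGMA_MIN), edm_scale(SIGMA_MAX))
cm_range = (cm_scale(SIGMA_MIN), cm_scale(SIGMA_MAX))

if __name__ == "__main__":
    print("raw:", (SIGMA_MIN, SIGMA_MAX))
    print("EDM:", edm_range)  # roughly (-0.98, 1.10)
    print("CM: ", cm_range)   # roughly (-978, 1096)
```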

Kinyugo avatar Dec 08 '23 11:12 Kinyugo

Thank you for your reply and your clear explanation.

I personally found that gradients sometimes explode, causing the network to output NaN, if rescaling is not properly applied (e.g., Song's rescale + Fourier embedding, or no rescale + Fourier embedding). This is especially severe when using more time steps, i.e., dividing the trajectory more finely. Based on your explanation, this is expected, since the value range may be too large.
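As a rough illustration of why a mismatched rescale could matter, the sketch below feeds both rescaled values into a random Fourier feature embedding of the form cos/sin(2π·f·x). The frequency scale of 16, the 8 frequencies, and all function names are assumptions made for this example, not taken from the thread. The phases grow linearly with the input, so the 1000x rescale yields phases 1000 times larger, which may be one contributing factor to the instability and is not a confirmed diagnosis.

```python
import math
import random


def fourier_embedding(x: float, freqs: list[float]) -> list[float]:
    """Random Fourier features of a scalar: cosines followed by sines of 2*pi*f*x."""
    phases = [2.0 * math.pi * f * x for f in freqs]
    return [math.cos(p) for p in phases] + [math.sin(p) for p in phases]


def max_phase(x: float, freqs: list[float]) -> float:
    """Largest absolute phase passed to cos/sin for input x."""
    return max(abs(2.0 * math.pi * f * x) for f in freqs)


rng = random.Random(0)  # fixed seed so the sketch is reproducible
# Gaussian frequencies with scale 16 (an assumed value for this sketch).
freqs = [rng.gauss(0.0, 1.0) * 16.0 for _ in range(8)]

sigma = 80.0
edm_input = math.log(sigma) / 4.0                   # EDM-style rescale
cm_input = 1000.0 * 0.25 * math.log(sigma + 1e-44)  # Song's rescale

if __name__ == "__main__":
    print("max phase, EDM-style input:", max_phase(edm_input, freqs))
    print("max phase, CM-style input: ", max_phase(cm_input, freqs))
```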

Maybe this is related to the experimental findings in your notebook, where you say that 10 time steps yielded better results. I think it is worth trying rescaling with a larger number of time steps.

Anyway, thanks again for your kind response.

wubowen416 avatar Dec 09 '23 04:12 wubowen416

Thank you for taking the time to run experiments and for sharing your findings.

I'll open a PR for this.

Kinyugo avatar Dec 09 '23 05:12 Kinyugo

@Kinyugo @wubowen416 In addition, I recently discovered something interesting. Although I am almost sure that neither the Consistency Models paper nor Improved Techniques for Consistency Training mentions anything about scaling the input of the network, the EDM paper does mention it:

[Image: excerpt from the EDM paper]

This can be seen in the EDM repository:

https://github.com/NVlabs/edm/blob/62072d2612c7da05165d6233d13d17d71f213fee/training/networks.py#L662-L665

I have also unexpectedly found something related in the original OpenAI repository:

https://github.com/openai/consistency_models/blob/e32b69ee436d518377db86fb2127a3972d0d8716/cm/karras_diffusion.py#L334-L349

Could this be related to the problems we are facing when trying to replicate the results of iCT? #17 #5

javiersgjavi avatar Oct 18 '24 11:10 javiersgjavi

Another hint: even the recent Consistency Models Made Easy (ECT) paper uses these two rescaling factors for the noise and the network input:

https://github.com/locuslab/ect/blob/4311059770f54821d151a9b0e1f76770a5f3930e/training/networks.py#L700-L718

javiersgjavi avatar Oct 18 '24 11:10 javiersgjavi