
Some questions about implementation of unet_shift

Open xinyangATK opened this issue 2 years ago • 2 comments

Thank you so much for releasing your code. I have some questions while reproducing your work. In the forward() function of class ResBlockShift(TimestepZBlock), out_rest(h) seems to set h to zero, which would make emb_z ineffective. Is there a problem in this module?

    # Adaptive Group Normalization
    out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
    scale, shift = torch.chunk(emb_out, 2, dim=1)        # from the timestep embedding
    z_scale, z_shift = torch.chunk(emb_z_out, 2, dim=1)  # from the z embedding
    h = (1. + z_scale) * (out_norm(h) * (1. + scale) + shift) + z_shift
    h = out_rest(h)

    return self.skip_connection(x) + h

xinyangATK commented on Apr 01, 2023

Thanks for your attention!

Do you mean the zero_module here? https://github.com/ckczzj/PDAE/blob/fbba0355634861196aed8b80b9ba4948ed210ab9/model/module/module.py#L362-L364

It is just a zero-initialization of the output conv layer. The zero-initialization makes the residual block work like an identity function at the beginning of training, which is a commonly used trick for stable training.
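For illustration, here is a minimal sketch of the idea. The zero_module helper below follows the standard guided-diffusion pattern that the link above points to; the toy block and its shapes are made up for the example:

    import torch
    import torch.nn as nn

    def zero_module(module):
        # Zero all parameters of a module and return it, in the style of
        # the guided-diffusion helper linked above.
        for p in module.parameters():
            p.detach().zero_()
        return module

    # Toy residual branch whose last conv is zero-initialized: at init it
    # outputs zeros, so the block reduces to its skip connection.
    block = nn.Sequential(
        nn.GroupNorm(8, 32),
        nn.SiLU(),
        zero_module(nn.Conv2d(32, 32, 3, padding=1)),
    )

    x = torch.randn(1, 32, 16, 16)
    out = x + block(x)             # skip_connection(x) + h, with skip = identity
    print(torch.allclose(out, x))  # True: the block is an identity before training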

Although the parameters are initialized to zero, their gradients still exist. After the first update of the network, almost all of them will be non-zero. The recent ControlNet work uses a similar zero-initialization (its "zero convolutions"), and the same question has been raised there.
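A minimal demonstration of that point (the shapes are again made up): the output of a zero-initialized conv is all zeros, but its weight gradient is computed from the non-zero input, so it is non-zero after the first backward pass:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(32, 32, 3, padding=1)
    nn.init.zeros_(conv.weight)   # same effect as zero_module on this layer
    nn.init.zeros_(conv.bias)

    x = torch.randn(1, 32, 16, 16)
    loss = conv(x).sum()          # the forward output is all zeros here
    loss.backward()
    # The weight gradient correlates the upstream gradient with the input,
    # so it does not vanish even though the weights themselves are zero.
    print(conv.weight.grad.abs().max() > 0)   # tensor(True)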

ckczzj commented on Apr 01, 2023

Thank you for your patient answer!

This really solved my confusion about this module.

xinyangATK commented on Apr 06, 2023