
Batch normalization

Open jmm34 opened this issue 11 months ago • 2 comments

I've found the Zuko library to be extremely beneficial for my work. I sincerely appreciate the effort that has gone into its development. In the Masked Autoregressive Flow paper (NeurIPS, 2017), the authors incorporated batch normalization following each autoregressive layer. Could this modification be integrated into the MaskedAutoregressiveTransform function?

jmm34 avatar Mar 21 '24 16:03 jmm34

Hello @jmm34, thanks for the kind words.

I am not a fan of batch normalization, as it often leads to train/test discrepancies that are hard to diagnose, but I see why one would want to use it (mainly faster training).

IMO the best way to add batch normalization in Zuko would be to implement a standalone (lazy) BatchNormTransform. The user could then insert batch norm transformations anywhere in the flow.

We would accept a PR that implements this.

Edit: I think that normalizing with the current batch statistics is invalid, because the transformation $y = f(x)$ would not be invertible: the statistics depend on the other samples in the batch, so it is impossible to recover $x$ from $y$ alone. Instead, we should use running statistics both during training and evaluation, and update these statistics during training. Also, I am not sure that the scale and shift parameters are relevant (zero mean, unit variance is the target).
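To illustrate the strategy above, here is a minimal sketch of such a transform in plain PyTorch. This is not Zuko's API; the class name, `forward`/`inverse` method names, and the momentum-based update are assumptions for the sake of the example. The key point is that both `forward` and `inverse` use the running statistics, so the map stays invertible, while the batch only contributes to updating those statistics during training.

```python
import torch
import torch.nn as nn

class BatchNormTransform(nn.Module):
    """Sketch of an invertible batch-norm transform (hypothetical, not Zuko's API).

    Running statistics are used in BOTH directions so that y = f(x) is
    invertible; they are updated from the current batch during training only.
    No learnable scale/shift, since zero mean and unit variance is the target.
    """

    def __init__(self, features: int, momentum: float = 0.1, eps: float = 1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.register_buffer('running_mean', torch.zeros(features))
        self.register_buffer('running_var', torch.ones(features))

    def forward(self, x: torch.Tensor):
        if self.training:
            # Update running statistics from the batch (exponential moving average).
            with torch.no_grad():
                self.running_mean.lerp_(x.mean(dim=0), self.momentum)
                self.running_var.lerp_(x.var(dim=0), self.momentum)

        std = torch.sqrt(self.running_var + self.eps)
        y = (x - self.running_mean) / std

        # log|det J| of the (element-wise affine) transform, one value per sample.
        ladj = -torch.log(std).sum().expand(x.shape[0])

        return y, ladj

    def inverse(self, y: torch.Tensor):
        # Uses the same running statistics, so inverse(forward(x)) == x.
        std = torch.sqrt(self.running_var + self.eps)
        return y * std + self.running_mean
```

In evaluation mode the statistics are frozen, so the transform is a fixed element-wise affine map and round-trips exactly:

```python
t = BatchNormTransform(3)
t.eval()

x = torch.randn(8, 3)
y, ladj = t(x)

assert torch.allclose(t.inverse(y), x, atol=1e-5)
```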

francois-rozet avatar Mar 21 '24 17:03 francois-rozet

Dear @francois-rozet, thank you very much for your quick reply. I will try to make a PR using the strategy you suggest.

jmm34 avatar Mar 21 '24 17:03 jmm34