
JAX SpectralNormalization shape error with vanilla nn.Dense layer

Open · norabelrose opened this issue 2 years ago

Applying edward2.jax.nn.SpectralNormalization to a vanilla Flax nn.Dense layer fails with a shape error because the kernel is left-multiplied with the input instead of right-multiplied. On line 143 of spectral_norm.py the current Edward2 code reads:

kernel_apply = lambda x: w.reshape(-1, w.shape[-1]) @ x

whereas the relevant line in Flax's Dense implementation is:

y = lax.dot_general(inputs, kernel,
                    (((inputs.ndim - 1,), (0,)), ((), ())),
                    precision=self.precision)

This leads to shape errors of the form:

TypeError: dot_general requires contracting dimensions to have the same shape, got (64,) and (95,).
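
For context, here is a minimal sketch of the kind of usage that triggers this (the shapes just mirror the error above; the exact SpectralNormalization constructor and call signature are my assumption, so treat this as a sketch rather than a verified snippet):

import jax
import jax.numpy as jnp
import flax.linen as nn
import edward2.jax as ed

# Assumed API: SpectralNormalization wraps the layer it normalizes.
layer = ed.nn.SpectralNormalization(nn.Dense(64))  # 64 output features
x = jnp.ones((8, 95))                              # batch of 95-dimensional inputs
params = layer.init(jax.random.PRNGKey(0), x)      # fails with the TypeError above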

Copying the SpectralNormalization implementation into my project and changing line 143 to read lambda x: x @ w.reshape(-1, w.shape[-1]) seems to resolve the issue.
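
To make the shape logic concrete, here is a standalone sketch of the two variants (the 95 and 64 dimensions again just mirror the error above):

import jax.numpy as jnp

w = jnp.ones((95, 64))   # nn.Dense kernel: (in_features, out_features)
v = jnp.ones((95,))      # power-iteration vector living in the input space

# Current code: w.reshape(-1, w.shape[-1]) @ v is (95, 64) @ (95,) and raises the TypeError.
u = v @ w.reshape(-1, w.shape[-1])   # proposed fix: (95,) @ (95, 64) -> (64,), as intended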

It seems pretty odd to me that no one has caught this before, so it's possible I'm somehow using the module incorrectly, but if this is expected behavior I think it should be more clearly documented.

norabelrose commented on May 27, 2022

Thanks for raising this! I'm not super familiar with the implementation, so cc-ing @jereliu.

dustinvtran commented on May 28, 2022

Hello! A gentle reminder that this issue needs a fix or a proper workaround.

hr0nix commented on Aug 16, 2022

Thanks! Tagging the author of the JAX SpectralNormalization, @fehiepsi.

jereliu commented on Aug 16, 2022

Sorry, I missed the messages. @norabelrose is correct that we need to use lambda x: x @ w.reshape(-1, w.shape[-1]), which aligns with the expectation that in_shape = (np.prod(w.shape[:-1]),). I will incorporate the fix and add more tests to cover this issue.
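
For what it's worth, a quick standalone check of that expectation (the conv-style kernel shape here is just an arbitrary example, not taken from the library):

import numpy as np
import jax.numpy as jnp

w = jnp.ones((3, 3, 16, 32))            # e.g. a conv kernel (kh, kw, in_ch, out_ch)
in_shape = (np.prod(w.shape[:-1]),)     # (144,): the flattened input dimension
v = jnp.ones(in_shape)

u = v @ w.reshape(-1, w.shape[-1])      # (144,) @ (144, 32) -> (32,)
assert u.shape == (w.shape[-1],)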

fehiepsi commented on Aug 16, 2022