edward2
JAX SpectralNormalization shape error with vanilla nn.Dense layer
Applying edward2.jax.nn.SpectralNormalization to a vanilla Flax nn.Dense layer fails with a shape error because the kernel is left-multiplied with the input instead of right-multiplied. On line 143 of spectral_norm.py, the current Edward2 code reads:
kernel_apply = lambda x: w.reshape(-1, w.shape[-1]) @ x
whereas the relevant line in Flax's Dense implementation is:
y = lax.dot_general(inputs, kernel,
(((inputs.ndim - 1,), (0,)), ((), ())),
precision=self.precision)
This leads to shape errors of the form:
TypeError: dot_general requires contracting dimensions to have the same shape, got (64,) and (95,).
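The mismatch is easy to reproduce outside of Edward2. The sketch below uses plain NumPy as a stand-in for the JAX arrays (the shape semantics of @ are the same), with hypothetical layer sizes of 95 input features and 64 outputs taken from the error message above:

```python
import numpy as np

# Hypothetical shapes matching the error message: a Dense layer
# mapping 95 input features to 64 output features.
in_features, out_features = 95, 64
w = np.zeros((in_features, out_features))  # Flax Dense kernel layout: (in, out)
x = np.ones(in_features)                   # a single input vector

kernel = w.reshape(-1, w.shape[-1])        # shape (95, 64)

# Buggy left-multiplication: tries to contract 64 against 95 and fails.
try:
    y = kernel @ x
except ValueError as e:
    print("left-multiply fails:", e)

# Corrected right-multiplication, matching Flax's dot_general contraction
# of the input's last axis with the kernel's first axis.
y = x @ kernel
print(y.shape)  # (64,)
```

In JAX the same left-multiplication surfaces as the dot_general TypeError quoted above rather than NumPy's ValueError.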
Copying the SpectralNormalization implementation into my project and changing line 143 to read lambda x: x @ w.reshape(-1, w.shape[-1]) seems to resolve the issue.
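To sanity-check that the corrected lambda still does what spectral normalization needs, here is a minimal power-iteration sketch in NumPy (a stand-in for the JAX code; the iteration scheme and variable names are illustrative, not copied from spectral_norm.py). With kernel_apply(x) = x @ kernel, applying it to a vector of shape (in_features,) computes W^T x, which is exactly the map power iteration needs to estimate the top singular value:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(95, 64))          # Dense kernel, (in, out) layout
kernel = w.reshape(-1, w.shape[-1])

# Corrected right-multiplication: for x of shape (in,), this is W^T x.
kernel_apply = lambda x: x @ kernel

# Power iteration to estimate the spectral norm sigma = largest singular value.
u = rng.normal(size=kernel.shape[0])   # left singular vector estimate, (95,)
for _ in range(100):
    v = kernel_apply(u)                # (64,) — would shape-error with the old lambda
    v /= np.linalg.norm(v)
    u = kernel @ v                     # (95,)
    u /= np.linalg.norm(u)
sigma = u @ kernel @ v                 # Rayleigh-quotient estimate of sigma

# sigma should approach the exact top singular value of the kernel.
print(np.allclose(sigma, np.linalg.svd(kernel, compute_uv=False)[0], rtol=1e-2))
```

With the original left-multiplied lambda, the very first kernel_apply(u) call would fail with the same contraction mismatch as above.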
It seems pretty odd to me that no one has caught this before, so it's possible I'm somehow using the module incorrectly, but if this is expected behavior I think it should be more clearly documented.
Thanks for raising this! I'm not super familiar with the implementation so cc-ing @jereliu.
Hello! A gentle reminder that this issue needs a fix or a proper workaround.
Thanks! Tagging author of JAX SpectralNormalization @fehiepsi
Sorry, I missed the messages. @norabelrose is correct that we need to use lambda x: x @ w.reshape(-1, w.shape[-1]), which aligns with the expectation that in_shape = (np.prod(w.shape[:-1]),). I will incorporate the fix and add more tests to cover this issue.