
New interface for spectral normalization

Open shoyer opened this issue 4 years ago • 3 comments

I noticed earlier today that Haiku has SpectralNormalization -- very cool!

I'm interested in implementing an improved version, which does a much better job estimating the norm for convolutional layers and should converge to the correct answer for any linear operator. The trick is to use auto-diff to calculate the transpose of the linear operator. In contrast, the current implementation is only accurate for dense matrices.

Here's my implementation in pure JAX: https://nbviewer.jupyter.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc

My question: how can I implement this in Haiku?

  • It feels like the right way to write this would be as a Module that takes another Module (or function) as an argument, but I don't know of any existing prior art for that. Does that make sense to you?
  • How do I call jax.vjp on a Module? I'm guessing (though to be honest I haven't checked yet) that the normal JAX function would break, given the way that Haiku adds mutable state.

shoyer, Sep 02 '20 07:09

Hi! I wrote the current Haiku implementation. I'll answer the questions in reverse order :)

How do I call jax.vjp on module?

If you look at https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/stateful.py you can see how we define a bunch of wrappers around Jax functions to work with Haiku. There's a lot of code but the idea is simple: temporarily grab the global state and thread it through the Jax function inputs, then make it global state again within the function (and reverse when returning). We don't have a wrapper around vjp right now (we have one for grad), but it shouldn't be too hard to do.
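
As a rough illustration of the same idea from outside a transform (this is not Haiku's internal wrapper; the layer, shapes, and names below are made up): once hk.transform has made the state explicit, plain jax.vjp already works on the pure apply function, and the internal wrappers do the analogous bookkeeping for code running inside a transform.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  return hk.Linear(8)(x)

f = hk.transform(forward)
x = jnp.ones((1, 4))
params = f.init(jax.random.PRNGKey(42), x)

# With params threaded in explicitly, the function is pure and jax.vjp applies.
apply_at_params = lambda x: f.apply(params, None, x)
y, vjp_fn = jax.vjp(apply_at_params, x)
x_bar, = vjp_fn(jnp.ones_like(y))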

how can I implement this in Haiku?

I definitely think it would be good to have this function somewhere, but I'm a little hesitant to suggest where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or parameters (unlike hk.SpectralNorm, which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, and have examples of how you could use it with a Haiku Module (along with Flax/etc. modules, since presumably they'd all work)?

As a second question, has this approach been used before, and do you know how well it works on GPUs/TPUs? The approximation we use is the one used by SNGAN (https://arxiv.org/pdf/1802.05957.pdf) and BigGAN, and we know it remains quite stable on accelerators. I'd be curious whether you have run any experiments checking the numerics of your exact approach.

Cogitans, Sep 02 '20 11:09

I definitely think it would be good to have this function somewhere, but I'm a little hesitant to suggest where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or parameters (unlike hk.SpectralNorm, which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, and have examples of how you could use it with a Haiku Module (along with Flax/etc. modules, since presumably they'd all work)?

For use in neural network training, I think you would still want to estimate the vector corresponding to the largest singular value in an online fashion.

Here's a clearer way to separate the logic:


import jax
import jax.numpy as jnp
from jax import lax

def _l2_normalize(x, eps=1e-4):
  return x * jax.lax.rsqrt((x ** 2).sum() + eps)

def _l2_norm(x):
  return jnp.sqrt((x ** 2).sum())

def _power_iteration(A, u, n_steps=10):
  """Update an estimate of the first right-singular vector of A()."""
  def fun(u, _):
    v, A_transpose = jax.vjp(A, u)  # v = A(u); A_transpose applies the adjoint
    u, = A_transpose(v)             # u <- A^T A u
    u = _l2_normalize(u)
    return u, None
  u, _ = lax.scan(fun, u, xs=None, length=n_steps)
  return u

def estimate_spectral_norm(f, x, seed=0, n_steps=10):
  """Estimate the spectral norm of f(x) linearized at x."""
  rng = jax.random.PRNGKey(seed)
  u0 = jax.random.normal(rng, x.shape)
  _, f_jvp = jax.linearize(f, x)
  u = _power_iteration(f_jvp, u0, n_steps)
  sigma = _l2_norm(f_jvp(u))
  return sigma

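For example (a made-up sanity check, assuming the definitions above): for a purely linear map, the estimate should approach the exact spectral norm of the matrix.

import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(1), (32, 64))
f = lambda x: A @ x                  # linear map, so its Jacobian is A everywhere
x = jnp.zeros((64,))                 # the linearization point does not matter here

sigma = estimate_spectral_norm(f, x, n_steps=50)
exact = jnp.linalg.norm(A, ord=2)    # largest singular value of A
# sigma and exact should agree closely after enough power-iteration steps.
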
I can imagine estimate_spectral_norm being a separately useful utility, but in a spectral normalization layer, you'd want to save the vector u0 as state on the layer and only use a handful of power iterations in each neural net evaluation.
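
A rough sketch of that state handling (my own illustration, not from the thread): it assumes f is a plain, state-free JAX function and uses hk.get_state/hk.set_state to carry u between calls, glossing over the Haiku-aware vjp/linearize question discussed above.

import jax
import jax.numpy as jnp
import haiku as hk

class OnlineSpectralNorm(hk.Module):
  """Sketch: refine a persistent power-iteration vector a few steps per call."""

  def __init__(self, n_steps=1, eps=1e-4, name=None):
    super().__init__(name=name)
    self.n_steps = n_steps
    self.eps = eps

  def __call__(self, f, x):
    # f is assumed to be a plain (state-free) JAX function; wrapping a Haiku
    # module would need the Haiku-aware vjp wrapper discussed earlier.
    u = hk.get_state("u", shape=x.shape, dtype=x.dtype,
                     init=hk.initializers.RandomNormal())
    _, f_jvp = jax.linearize(f, x)
    for _ in range(self.n_steps):
      v, f_vjp = jax.vjp(f_jvp, u)   # v = J u; f_vjp applies J^T
      u, = f_vjp(v)                  # u <- J^T J u
      u = u * jax.lax.rsqrt(jnp.sum(u ** 2) + self.eps)
    hk.set_state("u", u)             # carry the estimate to the next call
    return jnp.sqrt(jnp.sum(f_jvp(u) ** 2))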

As a second question, has this approach been used before/do you know how well it works on GPUs/TPUs?

The same approach (but written in a much more awkward/manual way) was used in this ICLR 2019 paper. Numerically, they should be identical. If you're using fully-connected layers, the calculation is exactly the same as the older method, just using autodiff instead of explicit matrix/vector products.
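
(A quick check of that claim, using nothing beyond standard JAX: for a dense matrix, the vjp is just multiplication by the transpose.)

import jax
import jax.numpy as jnp

W = jax.random.normal(jax.random.PRNGKey(0), (3, 5))
u = jax.random.normal(jax.random.PRNGKey(1), (5,))

f = lambda u: W @ u
v, f_vjp = jax.vjp(f, u)        # v = W @ u
u_autodiff, = f_vjp(v)          # autodiff applies W.T
u_explicit = W.T @ (W @ u)      # explicit matrix/vector products
# jnp.allclose(u_autodiff, u_explicit) should be True.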

From a fundamental perspective I would guess this is quite efficient and numerically stable on accelerators, because the operations it uses are exactly the same as those at the core of neural net training:

  • forward evaluation of linear layers (e.g., "convolution")
  • gradient evaluation of linear layers (e.g., "convolution transpose")

The cost of doing a single power iteration is thus roughly equivalent to that of pushing a single additional example through the neural net.

(The version I wrote in this comment is slightly simpler than the version in the ICLR 2019 paper, because it uses norm(A(u)) rather than v @ A(u) to calculate the singular value and only normalizes once per iteration, but I doubt those differences matter much, and they are not hard to change.)

shoyer, Sep 02 '20 17:09

hi @Cogitans, I'm trying to add spectral normalization into Flax and am modeling it after the Haiku version. I had some questions:

  • How is this used in a typical training loop? Are the params spectrally normalized after the gradient update using SNParamsTree (as in Algorithm 1, line 5, page 6 of the original paper)? If so, why not just create a helper function that does the spectral normalization and then use jax.tree_map to spectrally normalize the params? e.g. params = jax.tree_map(lambda x: spectral_normalize(x), params)
  • I'd like to understand the power iteration method used; it's different from what I've read in the Wikipedia article, and for some matrices it seems not to converge to the right eigenvalue even after many iteration steps.
  • why is lax.stop_gradient used?
  • why do we need to keep track of the state u0 and sigma?

chiamp, Sep 14 '23 21:09