
JAX-based gradient descent plateaus

Open · tolgarecep opened this issue 1 year ago • 11 comments

Description

I'm writing my own implementation of some papers that solve differential equations numerically using machine learning. However, my implementation fails to converge on examples with slightly more complex terms: for instance, it can solve the equation with exp(y), but it can't solve the one with 4*(exp(y)+exp(y/2)). And when my implementation does converge, it does so more slowly than the reported results; what they achieve in 100 epochs, I achieve in 1000.

Now, there is probably a lot to explain about this research to get us on the same page, but here I just want to know how JAX might be causing the plateau. I suspect the stacking operation in the first line of the function N(w, x). There I'm trying to represent a scalar by plugging it into the first 5 Chebyshev polynomials, placing each result in a vector according to the order of the polynomial, and thereby getting a 5-dimensional representation.

Running on GPU returns the same result, only slower.

import jax
from jax import random, grad, vmap, jit
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def chebyshev(n, x):
  if n == 0:
    return jnp.ones_like(x)
  elif n == 1:
    return x
  else:
    return 2*x*chebyshev(n-1, x) - chebyshev(n-2, x)

def N(w, x):
  # numerical transformation
  x = jnp.stack([chebyshev(i, x) for i in range(5)]).T
  x = x.T if x.shape[0]==1 else x
  # learning
  z = jnp.dot(w, x).squeeze()
  return jnp.tanh(z)

def trial(w, x):
  return 1 + (x**2)*N(w, x)

trial_vect = vmap(trial, (None, 0))
grad_trial = grad(trial, 1)
grad_trial_vect = vmap(grad_trial, (None, 0))
grad2_trial = grad(grad_trial, 1)
grad2_trial_vect = vmap(grad2_trial, (None, 0))

inputs = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.])

@jit
def error(params, inputs):
  # squared residual of the ODE: y'' + (2/x)*y' + 4*(2*exp(y) + exp(y/2)) = 0, with y = trial(w, x)
  y = trial_vect(params, inputs)
  term = grad2_trial_vect(params, inputs) + (2/inputs)*grad_trial_vect(params, inputs) + 4*(2*jnp.exp(y) + jnp.exp(y/2))
  return jnp.sum(0.5*term**2)

grad_error = jit(grad(error, 0))

key = random.PRNGKey(0)
params = random.normal(key, shape=(1, 5))

epochs = 1000001
lr = 0.0000005

for epoch in range(epochs):
    if epoch % 100000 == 0:
      print('epoch: %3d error: %.6f' % (epoch, error(params, inputs)))
    grads = grad_error(params, inputs)
    params = params - lr*grads

Result:

epoch:   0 error: 5474.779454
epoch: 100000 error: 1281.560668
epoch: 200000 error: 1279.912948
epoch: 300000 error: 1279.364945
epoch: 400000 error: 1279.091227
epoch: 500000 error: 1278.927101
epoch: 600000 error: 1278.817733
epoch: 700000 error: 1278.739640
epoch: 800000 error: 1278.681087
epoch: 900000 error: 1278.635557
epoch: 1000000 error: 1278.599140

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.23
jaxlib: 0.4.23
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

tolgarecep · Mar 12 '24 13:03

If the problem is slow convergence, I would check the learning rate. 0.0000005 seems very small.

jakevdp · Mar 12 '24 13:03

If the problem is slow convergence, I would check the learning rate. 0.0000005 seems very small.

Oh, I tried many learning rates; I even thought lr decay would solve it, but it didn't. I just forgot to run with the largest appropriate one before sharing. Anyway, with 0.5 it converges to error: 1278.272054 and then gets stuck there.

tolgarecep · Mar 12 '24 14:03

Without digging further, my guess would be that you've found a local optimum.

jakevdp · Mar 12 '24 14:03

Without digging further, my guess would be that you've found a local optimum.

The paper omits training details, so this is possible. I'd like to discuss here whether this implementation of the hard-coded representation is correct: does it hurt backpropagation, or is it fine?

  # numerical transformation
  x = jnp.stack([chebyshev(i, x) for i in range(5)]).T
  x = x.T if x.shape[0]==1 else x

tolgarecep · Mar 12 '24 15:03

That formulation should be fine when it comes to autodiff. The only issue I see is that the recursive chebyshev implementation you're using is not as efficient as it could be: I suspect that the compiler optimizes it well enough, but it will generate a lot of duplicate statements and lead to slow compile times. It shouldn't affect the numerical results though.
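
For example, an iterative version of the recurrence traces each polynomial value only once; a minimal sketch (the chebyshev_iter name is just illustrative):

def chebyshev_iter(n, x):
  # iterative form of the recurrence T_{k+1}(x) = 2*x*T_k(x) - T_{k-1}(x)
  t_prev, t_curr = jnp.ones_like(x), x
  if n == 0:
    return t_prev
  for _ in range(n - 1):
    t_prev, t_curr = t_curr, 2*x*t_curr - t_prev
  return t_curr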

jakevdp · Mar 12 '24 16:03

(that said, I've not looked at the paper you're trying to implement, so I can't comment on whether what you're doing is similar to what it is doing)

jakevdp · Mar 12 '24 16:03

Is the gradient norm approximately zero, i.e. do the iterates stop moving (in addition to the loss not progressing)?

To rule out numerical approximation issues, you could try comparing to Autograd, since it is close to a drop-in replacement (no vmap though) and it'd use NumPy for numerical operations.

Can you link the paper in question?
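
For example, something like the following would show whether the iterates have effectively stopped moving (a minimal sketch reusing the grad_error and params from your script):

g = grad_error(params, inputs)
print('grad norm:', jnp.linalg.norm(g))  # a value near zero suggests a stationary point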

mattjj · Mar 12 '24 18:03

Also, another thing I noticed: if you change the initial params state by using a different seed for random.key, you end up with a different answer (also with zero gradient). This strongly suggests that you have a multimodal likelihood space, and you're landing in a local rather than a global optimum. One way around this would be to use more sophisticated optimization methods; see for example the jaxopt package.
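
For example, a quick seed sweep makes this visible (a minimal sketch reusing error, grad_error, lr, and inputs from your script; the inner loop is abbreviated):

for seed in range(5):
  p = random.normal(random.PRNGKey(seed), shape=(1, 5))
  for _ in range(100000):  # same vanilla gradient descent as in your script
    p = p - lr * grad_error(p, inputs)
  # different seeds settle at different losses, each with a near-zero gradient
  print(seed, error(p, inputs), jnp.linalg.norm(grad_error(p, inputs)))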

jakevdp · Mar 12 '24 18:03

Is the gradient norm approximately zero, i.e. do the iterates stop moving (in addition to the loss not progressing)?

To rule out numerical approximation issues, you could try comparing to Autograd, since it is close to a drop-in replacement (no vmap though) and it'd use NumPy for numerical operations.

Can you link the paper in question?

Tuning the learning rate during training, for the above equation (loss function) I got error: 70.535873 and grad norm: 0.010273, so yes, the gradients vanish. (Additional question here: I knew about learning rate decay, but I had never increased the learning rate during training before, yet increasing it got me out of local optima here. But if I had started with that larger learning rate, the loss would have returned nan. Any source on this?)

This particular equation is not solved in the paper proposing this architecture, but I expect the architecture to solve it, or at least reach an error close to 0. Except for the final two examples, I can replicate the results. Here is the paper: Chebyshev Neural Network based model for solving Lane–Emden type equations

Also, another thing I noticed: if you change the initial params state by using a different seed for random.key, you end up with a different answer (also with zero gradient). This strongly suggests that you have a multimodal likelihood space, and you're landing in a local rather than a global optimum. One way around this would be to use more sophisticated optimization methods; see for example the jaxopt package.

A few experiments with Adam resulted in the same kind of behavior. And the paper I mentioned, which reported convergence in 100 epochs where my JAX implementation took 1000, reported using vanilla GD, so that's how I did it. I'll look into it though, thanks!
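
For context, the Adam experiments were roughly along these lines (a minimal sketch using optax; the learning rate is illustrative):

import optax

optimizer = optax.adam(learning_rate=1e-3)  # illustrative value
opt_state = optimizer.init(params)

for epoch in range(10000):
  grads = grad_error(params, inputs)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)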

tolgarecep · Mar 12 '24 19:03

Additional question here: I knew about learning rate decay, but I had never increased the learning rate during training before, yet increasing it got me out of local optima here. But if I had started with that larger learning rate, the loss would have returned nan. Any source on this?

I think it depends on the problem at hand, but in some problems choosing good initializers is difficult, and so if we "warm up" (i.e. increase) the step size at first we can get into better optimization regimes. I think that is common in deep learning: see e.g. this paper and references, though I don't know what the canonical references are. On the theory side, the work of Altschuler and Parrilo, like here and here (and Altschuler's 2015 master's thesis under Parrilo), offers some explanation of non-monotonic, not-just-constant-or-decaying step sizes.
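
Concretely, a warmup can be as simple as ramping the step size linearly over the first iterations; a minimal sketch against your training loop (base_lr and warmup_steps here are placeholders to tune):

def warmup_lr(step, base_lr=0.5, warmup_steps=10000):
  # linearly ramp from near zero up to base_lr, then hold it constant
  return base_lr * jnp.minimum(1.0, (step + 1) / warmup_steps)

for epoch in range(epochs):
  grads = grad_error(params, inputs)
  params = params - warmup_lr(epoch) * grads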

Should we keep this issue open, or close it until new questions arise? (Sorry, I'm just not sure if there are outstanding JAX questions at the moment!)

mattjj · Mar 12 '24 19:03

Additional question here: I knew about learning rate decay, but I had never increased the learning rate during training before, yet increasing it got me out of local optima here. But if I had started with that larger learning rate, the loss would have returned nan. Any source on this?

I think it depends on the problem at hand, but in some problems choosing good initializers is difficult, and so if we "warm up" (i.e. increase) the step size at first we can get into better optimization regimes. I think that is common in deep learning: see e.g. this paper and references, though I don't know what the canonical references are. On the theory side, the work of Altschuler and Parrilo, like here and here (and Altschuler's 2015 master's thesis under Parrilo), offers some explanation of non-monotonic, not-just-constant-or-decaying step sizes.

Should we keep this issue open, or close it until new questions arise? (Sorry, I'm just not sure if there are outstanding JAX questions at the moment!)

I think we can close it, since there is no clear evidence that the problem is JAX-related. I'm not sure if we should consider it "completed", though. Thanks for the quick responses.

tolgarecep · Mar 13 '24 17:03