lagrangian_nns

Issues when I open the main branch on Google Colab and use `base-nn-double-pendulum.ipynb`

Status: Open. MariosGkMeng opened this issue 2 years ago (22 comments).

Issues when I open the main branch on Google Colab and use `base-nn-double-pendulum.ipynb`:

  1. Cannot import `from jax.experimental import stax`. Fixed it by replacing the line with: `from jax.example_libraries import stax`
  2. Issues with `HyperparameterSearch.py`:
    • ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 3) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function _odeint_wrapper is non-hashable. for the function odeint.
The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
/content/lagrangian_nns/experiment_dblpend/data.py in get_trajectory_analytic(y0, times, **kwargs)
     38 @partial(jax.jit, backend='cpu')
     39 def get_trajectory_analytic(y0, times, **kwargs):
---> 40     return odeint(analytical_fn, y0, t=times, rtol=1e-10, atol=1e-10, **kwargs)
     41 
     42 def get_dataset(seed=0, samples=1, t_span=[0,2000], fps=1, test_split=0.5, **kwargs):

TypeError: odeint() got an unexpected keyword argument 'mxsteps'
  4. When I run `loss(get_params(opt_state), batch_data, 0.0)`, I get:
/content/lagrangian_nns__modified/hyperopt/HyperparameterSearch.py in dynamics(q, q_t)
     32 #     assert q.shape == (2,)
     33     state = wrap_coords(jnp.concatenate([q, q_t]))
---> 34     return jnp.squeeze(nn_forward_fn(params, state), axis=-1)
     35   return dynamics
     36 

/usr/local/lib/python3.7/dist-packages/jax/numpy/lax_numpy.py in squeeze(a, axis)
   1165     axis = frozenset(_canonicalize_axis(i, ndim(a)) for i in axis)
   1166     if _any(shape_a[a] != 1 for a in axis):
-> 1167       raise ValueError("cannot select an axis to squeeze out which has size "
   1168                        "not equal to one")
   1169     newshape = [d for i, d in enumerate(shape_a)

ValueError: cannot select an axis to squeeze out which has size not equal to one

I managed to fix this by changing `return jnp.squeeze(nn_forward_fn(params, state), axis=-1)` to `return nn_forward_fn(params, state)`.

MariosGkMeng, Nov 24 '22

Thanks, will fix the issue with the moved `stax` location soon. @greydanus, do you want to update your Colab notebook?

MilesCranmer, Nov 24 '22

What line is issue 2 from?

MilesCranmer, Nov 24 '22

Ah, JAX changed `mxsteps` to `mxstep`... Why make such a small breaking change I do not know...
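For reference, a minimal sketch of an `odeint` call using the newer keyword name. The toy right-hand side `decay` (dy/dt = -y) and the loosened float32 tolerances are illustrative assumptions, not the repo's `analytical_fn` or its settings:

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def decay(y, t):
    # dy/dt = -y, whose exact solution is y0 * exp(-t)
    return -y


times = jnp.linspace(0.0, 1.0, 5)
y0 = jnp.array([1.0])

# Newer JAX spells the step limit `mxstep` (no trailing "s")
ys = odeint(decay, y0, times, rtol=1e-6, atol=1e-6, mxstep=10_000)
print(ys[:, 0])
```

Passing the old name, e.g. `mxsteps=100`, fails with the same `TypeError: odeint() got an unexpected keyword argument 'mxsteps'` seen in the traceback above.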

MilesCranmer, Nov 24 '22

PR #4 will fix this.

MilesCranmer, Nov 24 '22

> What line is issue 2 from?

Once I replicate the error (not receiving it right now), I will specify :)

> Ah, JAX changed mxsteps to mxstep... Why make such a small breaking change I do not know...

Haha, actually it does ring a bell, but after many different packages and trials, I had forgotten about it. Thanks for fixing 😀

MariosGkMeng, Nov 24 '22

@MilesCranmer I added a 4th issue to my initial comment, with a (probably improper) fix.

MariosGkMeng, Nov 24 '22

For the 4th issue, what is the full traceback? i.e., what chunk of code is it coming from?

You should try this replacement instead:

  def dynamics(q, q_t):
    state = wrap_coords(jnp.concatenate([q, q_t]))
    updated_state = nn_forward_fn(params, state)
    if len(updated_state.shape) == 2:
        return jnp.squeeze(updated_state, axis=-1)
    else:
        return updated_state

Not sure why one of the states has dim=1
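A minimal sketch of an alternative guard keyed on the size of the last axis rather than the number of dimensions, assuming a stax MLP with a single output unit. The layer sizes, `squeeze_if_singleton`, and `learned_lagrangian` are hypothetical names, and the repo's `wrap_coords` is omitted. `jnp.squeeze(x, axis=-1)` raises the `ValueError` above whenever that axis does not have size 1, so the helper only squeezes when it does:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Hypothetical MLP with one output unit, standing in for the repo's nn_forward_fn
init_fn, nn_forward_fn = stax.serial(stax.Dense(32), stax.Softplus, stax.Dense(1))
_, params = init_fn(jax.random.PRNGKey(0), (-1, 4))


def squeeze_if_singleton(x):
    # Drop the trailing axis only when it has size 1; otherwise return x unchanged
    return jnp.squeeze(x, axis=-1) if x.shape[-1:] == (1,) else x


def learned_lagrangian(params):
    def lagrangian(q, q_t):
        state = jnp.concatenate([q, q_t])  # shape (4,) for a double pendulum
        return squeeze_if_singleton(nn_forward_fn(params, state))  # (1,) -> ()
    return lagrangian


L = learned_lagrangian(params)
q, q_t = jnp.array([0.1, 0.2]), jnp.array([0.0, 0.0])
dL_dqt = jax.grad(L, argnums=1)(q, q_t)  # grad needs a scalar output
print(dL_dqt)
```

For an unbatched `state` of shape `(4,)`, the network returns shape `(1,)`, and squeezing it gives the scalar that `jax.grad` requires.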

MilesCranmer, Nov 24 '22

> For the 4th issue, what is the full traceback? i.e., what chunk of code is it coming from?

Updated the comment on bullet number 4.

> What line is issue 2 from?

I reproduced it just by having the most recent version of JAX. I had reverted to an older version (0.1.68), and the error did not occur. But I installed the newest version again because 0.1.68 could not find my Colab GPU, even though TensorFlow can find it.
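A quick, generic way to check which JAX version is installed and which devices it can see; this is a sketch, not something from the repo, and `describe_jax_setup` is a hypothetical helper name:

```python
import jax


def describe_jax_setup():
    # Collect the installed version, the default backend, and the visible devices
    return {
        "version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu" when no accelerator is found
        "devices": [str(d) for d in jax.devices()],
    }


print(describe_jax_setup())
```

If `backend` reports `cpu` on a GPU runtime, one common reason is that the installed `jaxlib` was built without CUDA support.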

MariosGkMeng, Nov 24 '22

@MilesCranmer Also, I noticed that I did not fully reply to your question regarding issue 2:

The issue is triggered when I run `from HyperparameterSearch import learned_dynamics`,

and only when I am using the recent JAX version (0.3.25).
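In JAX's `odeint` implementation, static argument index 3 of the internal `_odeint_wrapper` is `mxstep`, which must be hashable. One plausible cause of the issue 2 `ValueError`, assumed here rather than confirmed from the repo, is that `mxstep` reaches `odeint` through the keyword arguments of a `jax.jit`-decorated function, where JAX turns it into a tracer. A minimal sketch of one workaround marks it static with `static_argnames`; `get_trajectory` and `decay` are hypothetical stand-ins, not the repo's functions:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def decay(y, t):
    # Toy dynamics dy/dt = -y
    return -y


# Declaring `mxstep` static keeps it a plain (hashable) Python int inside the
# trace, so odeint can forward it as a static argument.
@partial(jax.jit, static_argnames=("mxstep",))
def get_trajectory(y0, times, mxstep=10_000):
    return odeint(decay, y0, times, rtol=1e-6, atol=1e-6, mxstep=mxstep)


times = jnp.linspace(0.0, 1.0, 5)
ys = get_trajectory(jnp.array([1.0]), times, mxstep=5_000)
print(ys[:, 0])
```

Without `static_argnames`, the `mxstep` keyword would be traced like any other jit argument.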

MariosGkMeng, Nov 26 '22

I am running JAX 0.4.6. I corrected the import statements for `stax` and `optimizers`. I replaced `mxsteps` with `mxstep` in one place. I changed the definition of the `dynamics` function in `train.py` and `HyperparameterSearch.py`. Should the recommended change be made in both places? I am still getting the following error in cell [4] of `DoublePendulum-Baseline.ipynb`. The same error appears in other notebooks. (Attached: error_lagrange.txt)
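In current JAX, both `stax` and `optimizers` live under `jax.example_libraries` (they were previously under `jax.experimental`). A minimal sketch of the updated imports together with a standard `optimizers.adam` training step; the toy network, regression data, and `update` function are hypothetical and unrelated to the repo's Lagrangian loss:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers, stax

# Toy MLP: 4 inputs -> 1 output
init_fn, apply_fn = stax.serial(stax.Dense(16), stax.Softplus, stax.Dense(1))
_, init_params = init_fn(jax.random.PRNGKey(0), (-1, 4))

# optimizers.adam returns (init, update, get_params) functions
opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init(init_params)


def loss(params, x, y):
    # Mean squared error of the network's predictions
    return jnp.mean((apply_fn(params, x) - y) ** 2)


@jax.jit
def update(i, opt_state, x, y):
    grads = jax.grad(loss)(get_params(opt_state), x, y)
    return opt_update(i, grads, opt_state)


x = jax.random.normal(jax.random.PRNGKey(1), (32, 4))
y = jnp.sum(x ** 2, axis=1, keepdims=True)
for i in range(50):
    opt_state = update(i, opt_state, x, y)
print(loss(get_params(opt_state), x, y))
```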

zdjordje123, Mar 26 '23

Hi @zdjordje123, do you want to push your edits as a draft pull request and I can help work on it with you? It would be great if the code would be updated for modern JAX!

MilesCranmer, Mar 26 '23

I would be glad to help. Unfortunately, I am not sure my edits are very useful. If you give me precise tasks, I could work on them. Regards, Zoran Djordjevic

zdjordje123, Mar 27 '23

I tried creating a virtual environment with JAX 0.1.68, which I saw someone mention as giving error-free runs. I was hoping to step through both the old and new environments in parallel and see the differences. Unfortunately, my JAX 0.1.68 environment refuses to work. I will only have time to revisit this code at the end of May. If someone migrates the code to the newest JAX, I will be happy to use it and reference it. Otherwise, I could try to migrate it myself then. Best regards, Zoran Djordjevic

zdjordje123, Mar 27 '23

Will let you know. Busy time, but hope I can get a chance to update things to new JAX.

MilesCranmer, Mar 30 '23

Hi, I am running JAX 0.4.13. I corrected the import statements for `stax` and `optimizers`, replaced `mxsteps` with `mxstep`, and changed the definition of the `dynamics` function in `HyperparameterSearch.py`. However, when I run `from HyperparameterSearch import learned_dynamics`, I get this error: UnexpectedTracerError: Found a JAX Tracer object passed as an argument to a custom_vjp function in a position indicated by nondiff_argnums as non-differentiable. Tracers cannot be passed as non-differentiable arguments to custom_vjp functions; instead, nondiff_argnums should only be used for arguments that can't be or contain JAX tracers, e.g. function-valued arguments. In particular, array-valued arguments should typically not be indicated as nondiff_argnums. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError. How can I fix it? Thanks.
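The error message recommends keeping array-valued arguments out of `nondiff_argnums`. A minimal sketch of that pattern on a hypothetical toy function, `scaled_square` (not the repo's code, and it assumes the failing call site is a `custom_vjp` with an array in `nondiff_argnums`): `scale` is passed as an ordinary argument, and the backward rule returns a zero cotangent for it to treat it as non-differentiable, so it may safely be a tracer under `jax.jit`:

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scaled_square(x, scale):
    return scale * x ** 2


def scaled_square_fwd(x, scale):
    # Save both inputs as residuals for the backward pass
    return scale * x ** 2, (x, scale)


def scaled_square_bwd(res, g):
    x, scale = res
    # One cotangent per primal input: the true derivative for x, zeros for scale
    return g * 2.0 * scale * x, jnp.zeros_like(scale)


scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

# `scale` is traced by jit here, which is fine because it is not in nondiff_argnums
grad_fn = jax.jit(jax.grad(scaled_square))
print(grad_fn(3.0, 2.0))  # d/dx (2 * x**2) at x = 3
```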

xzhuzhu, Jul 25 '23

Hmm, I will take a look this weekend. Thanks for the report!!

MilesCranmer, Jul 28 '23

(If you find a fix before I can get to it I will gladly accept a PR btw)

MilesCranmer, Jul 28 '23

> Hi, I am running JAX 0.4.13. [...] UnexpectedTracerError: Found a JAX Tracer object passed as an argument to a custom_vjp function in a position indicated by nondiff_argnums as non-differentiable. [...] How can I fix it, thanks.

Any updates on this issue?

umerhuzaifa, Nov 10 '23

Thanks for the ping and sorry for not finding enough time to fix this yet. Please keep pinging me though, eventually will be able to get it done.

(However, I will also say for anybody reading this that I am always immensely appreciative of PRs, and will gladly review + merge them!)

MilesCranmer, Nov 24 '23

bump :) I've been trying to solve this problem for a few hours. Can't get it to work! I'm on the "fix-stax-import" branch, all other problems seem fixed there.

Raul-Create, May 06 '24

One other option is to use https://github.com/astrofrog/pypi-timemachine to install the exact versions of all dependencies as they existed at publication time. Maybe we could then pin those in requirements.txt.

But yes, I'd still like to fix this for recent JAX; I've just been too short on time to do it myself.

MilesCranmer, May 06 '24

Would someone be willing to test out the code in https://github.com/MilesCranmer/lagrangian_nns/pull/8? I've tried to fix most of the issues caused by JAX updates.

MilesCranmer, Jun 25 '24