Carla_iLQR_MPC

Type errors with "ilqr_jax_MPC.py"

Status: Open. HouJun19 opened this issue 1 year ago • 5 comments

Hi, I'm trying to run this code on my computer (Ubuntu 18.04) and have run into some problems.

    pygame 2.1.2 (SDL 2.0.16, Python 3.8.13)
    Hello from the pygame community. https://www.pygame.org/contribute.html
    Traceback (most recent call last):
      File "ilqr_jax_MPC.py", line 164, in <module>
        jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init()
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/api.py", line 527, in cache_miss
        out_flat = xla.xla_call(
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1937, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1953, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 687, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 208, in _xla_call_impl
        compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/linear_util.py", line 295, in memoized_fun
        ans = call(fun, *args)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 257, in _xla_callable_uncached
        return lower_xla_callable(fun, device, backend, name, donated_invars, False,
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/profiler.py", line 294, in wrapper
        return func(*args, **kwargs)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 302, in lower_xla_callable
        jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/profiler.py", line 294, in wrapper
        return func(*args, **kwargs)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2188, in trace_to_jaxpr_final2
        jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2139, in trace_to_subjaxpr_dynamic2
        out_tracers = map(trace.full_raise, ans)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/util.py", line 47, in safe_map
        return list(map(f, *args))
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 415, in full_raise
        return self.pure(val)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1761, in new_const
        aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1163, in get_aval
        return concrete_aval(x)
      File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1155, in concrete_aval
        raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
    jax._src.traceback_util.UnfilteredStackTrace: TypeError: Value <CompiledFunction of <function jacfwd.<locals>.jacfun at 0x7fe3d73ac670>> with type <class 'jaxlib.xla_extension.CompiledFunction'> is not a valid JAX type

    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.

    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "ilqr_jax_MPC.py", line 164, in <module>
        jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init()

I'm using jax 0.3.16 and jaxlib 0.3.15 with CUDA 11.1 and cuDNN 8.0.5. Do you know how to fix this problem?

Thank you very much; I look forward to your reply.
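The final `TypeError` in the traceback says JAX was asked to treat a compiled `jacfwd` function as an array. The frames show the call to `derivative_init()` going through `jit` (`cache_miss` → `xla_call`) and then failing while converting its outputs, so a likely reading is that `derivative_init` is itself wrapped in `jax.jit` and returns jitted derivative functions; newer JAX versions reject that because a jitted function may only return arrays (or pytrees of arrays). Besides downgrading, one possible workaround is sketched below, assuming `derivative_init` is decorated with `jax.jit` (the source file isn't shown in this thread): keep it as a plain Python function and jit each derivative individually. The cost and dynamics functions are hypothetical placeholders, not the project's.

```python
import jax
import jax.numpy as jnp

# Hypothetical placeholders: the real cost and dynamics in ilqr_jax_MPC.py are not shown here.
def running_cost(x, u):
    return jnp.sum(x ** 2) + 0.1 * jnp.sum(u ** 2)

def final_cost(x):
    return jnp.sum(x ** 2)

def dynamics(x, u, dt=0.1):
    # Toy point mass: x = [position, velocity], u = [acceleration], 0.1 s step.
    return x + dt * jnp.array([x[1], u[0]])

# Deliberately NOT decorated with @jax.jit: a jitted function may only return
# arrays (or pytrees of arrays), never other functions.
def derivative_init():
    jac_l = jax.jit(jax.jacfwd(running_cost, argnums=(0, 1)))   # returns (l_x, l_u)
    hes_l = jax.jit(jax.hessian(running_cost, argnums=(0, 1)))  # returns ((l_xx, l_xu), (l_ux, l_uu))
    jac_l_final = jax.jit(jax.jacfwd(final_cost))               # l_x at the final state
    hes_l_final = jax.jit(jax.hessian(final_cost))              # l_xx at the final state
    jac_f = jax.jit(jax.jacfwd(dynamics, argnums=(0, 1)))       # returns (f_x, f_u)
    return jac_l, hes_l, jac_l_final, hes_l_final, jac_f

jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init()

x, u = jnp.array([1.0, 2.0]), jnp.array([0.5])
f_x, f_u = jac_f(x, u)  # f_x has shape (2, 2), f_u has shape (2, 1)
```

Each returned function is compiled lazily on its first call and cached afterwards, so the iLQR loop can call them repeatedly without recompiling as long as the input shapes stay fixed.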

HouJun19 · Aug 27 '22 10:08

Hi @HouJun19, it seems that your jax and jaxlib versions are too new. Could you take a look at this previous issue? https://github.com/YukunXia/Carla_iLQR_MPC/issues/5#issuecomment-992945428

YukunXia · Aug 28 '22 02:08

Thanks for your reply! Actually, I have run this code successfully with jaxlib 0.1.47, but it turns out to be very slow: completing one loop takes more than 2 hours, and there is a warning: "Can not found GPU, fall back to CPU". So I want to use the GPU to accelerate the code, but I have tried many versions and all of them failed. My GPU is an NVIDIA RTX 3060 Ti, and some older versions of jaxlib and CUDA don't support it. When I tried jaxlib 0.1.52, I got the same error, "cannot import name 'pytree' from 'jaxlib'", while newer versions of jax and jaxlib produce other errors, like the one in my original question. I really don't know how to solve them. Sorry if I've misunderstood something; I've only just started using Python and still lack experience. Looking forward to your reply. Many thanks.
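The "fall back to CPU" warning means the installed jaxlib did not expose a GPU backend. As a quick check, here is a small sketch (the helper name `jax_environment_report` is hypothetical) that prints the installed `jax`/`jaxlib` versions and the devices JAX can see. Note that the plain `jaxlib` wheel from PyPI is CPU-only, so a CUDA-enabled build matching the local CUDA/cuDNN versions is needed before a `gpu` device can show up here.

```python
import importlib.metadata

import jax


def jax_environment_report():
    """Collect installed versions and the devices JAX can see (hypothetical helper)."""
    devices = jax.devices()
    return {
        "jax": importlib.metadata.version("jax"),
        "jaxlib": importlib.metadata.version("jaxlib"),
        "default_backend": jax.default_backend(),  # "cpu" means no GPU backend was found
        "devices": [str(d) for d in devices],
        "has_gpu": any(d.platform == "gpu" for d in devices),
    }


if __name__ == "__main__":
    for name, value in jax_environment_report().items():
        print(f"{name}: {value}")
```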

HouJun19 · Aug 28 '22 04:08

I ran my code on an old laptop CPU (i7-7700HQ) two years ago, so your machine should be more than capable of handling it. I could easily reach a 10 Hz decision frequency; what held things back was algorithm stability or the quality of the Carla data, not computation speed. My dynamical model assumes a time step of 0.1 s.

Maybe you could try some simpler JAX code on your CPU and see whether it runs at a reasonable speed.
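A minimal sketch of such a sanity check, independent of the project's own code: it times a small jitted rollout of a hypothetical point-mass model with a 0.1 s step, separating the first call (tracing plus XLA compilation) from cached calls. The `block_until_ready()` calls matter because JAX dispatches work asynchronously; without them the timer could stop before the computation finishes.

```python
import time

import jax
import jax.numpy as jnp

DT = 0.1  # 0.1 s step, matching the time interval mentioned above


def rollout(x0, us):
    """Roll a toy 1-D point mass forward with lax.scan (hypothetical stand-in model)."""
    def step(x, u):
        # x = [position, velocity]; u = scalar acceleration for this step
        x_next = x + DT * jnp.array([x[1], u])
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, us)
    return xs


rollout_jit = jax.jit(rollout)


def benchmark(fn, *args, repeats=50):
    """Return (first_call_seconds, mean_cached_call_seconds)."""
    t0 = time.perf_counter()
    fn(*args).block_until_ready()  # first call includes tracing and compilation
    first = time.perf_counter() - t0
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # wait for the asynchronous result
    return first, (time.perf_counter() - t0) / repeats


if __name__ == "__main__":
    first, steady = benchmark(rollout_jit, jnp.zeros(2), jnp.ones(50))  # 50-step horizon (arbitrary)
    print(f"first call: {first * 1e3:.1f} ms, cached call: {steady * 1e3:.3f} ms")
```

A 10 Hz loop leaves a budget of 100 ms per step. If the cached call of a toy like this is far below that but the first call is slow, the cost is compilation rather than execution.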

As for the GPU, whether it speeds things up depends on the program. Mine uses only a shallow neural network, so you might see a small performance gain, but I doubt it will be much.
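To see whether a GPU helps for a network of this size, one could time the same jitted forward pass on each available device. The sketch below uses a hypothetical MLP (the layer sizes `(6, 32, 32, 4)` are arbitrary, not the project's model). For a single input, fixed per-call overhead such as dispatch and kernel launch tends to dominate, so the GPU often gives little or no gain; larger batches are where it usually pulls ahead.

```python
import time

import jax
import jax.numpy as jnp


def init_mlp(key, sizes=(6, 32, 32, 4)):
    """Hypothetical shallow MLP; layer sizes are arbitrary."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


@jax.jit
def mlp(params, x):
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def time_on_device(device, batch, repeats=100):
    """Mean seconds per forward pass with params and inputs committed to `device`."""
    params = jax.device_put(init_mlp(jax.random.PRNGKey(0)), device)
    x = jax.device_put(jnp.ones((batch, 6)), device)
    mlp(params, x).block_until_ready()  # compile outside the timed loop
    t0 = time.perf_counter()
    for _ in range(repeats):
        mlp(params, x).block_until_ready()
    return (time.perf_counter() - t0) / repeats


if __name__ == "__main__":
    devices = [jax.devices("cpu")[0]]
    if jax.default_backend() != "cpu":
        devices.append(jax.devices()[0])
    for dev in devices:
        for batch in (1, 4096):
            print(f"{dev} batch={batch}: {time_on_device(dev, batch) * 1e6:.1f} us/call")
```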

YukunXia · Aug 28 '22 05:08

I should have provided a Docker environment to make testing easier, but right now I don't have enough time to fully revisit this project and upgrade the dependencies.

I didn't expect Jax to evolve this much :(

YukunXia · Aug 28 '22 05:08

Thanks a lot! I'll try some JAX code again. But the first time I ran this with jaxlib, every step seemed to take more than 4 s, which is really weird. I just ran ./CarlaUE4.sh and then python ilqr_jax_MPC.py; I wonder if I've missed a step. T^T

HouJun19 · Aug 28 '22 05:08