Hi, I'm trying to run this code on my computer (Ubuntu 18.04) and have run into some problems.
pygame 2.1.2 (SDL 2.0.16, Python 3.8.13)
Hello from the pygame community. https://www.pygame.org/contribute.html
Traceback (most recent call last):
File "ilqr_jax_MPC.py", line 164, in <module>
jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init()
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/api.py", line 527, in cache_miss
out_flat = xla.xla_call(
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1937, in bind
return call_bind(self, fun, *args, **params)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1953, in call_bind
outs = top_trace.process_call(primitive, fun, tracers, params)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 687, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 208, in _xla_call_impl
compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/linear_util.py", line 295, in memoized_fun
ans = call(fun, *args)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 257, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 302, in lower_xla_callable
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2188, in trace_to_jaxpr_final2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2139, in trace_to_subjaxpr_dynamic2
out_tracers = map(trace.full_raise, ans)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/util.py", line 47, in safe_map
return list(map(f, *args))
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 415, in full_raise
return self.pure(val)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1761, in new_const
aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1163, in get_aval
return concrete_aval(x)
File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1155, in concrete_aval
raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Value <CompiledFunction of <function jacfwd.<locals>.jacfun at 0x7fe3d73ac670>> with type <class 'jaxlib.xla_extension.CompiledFunction'> is not a valid JAX type
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "ilqr_jax_MPC.py", line 164, in <module>
jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init()
My jax version is 0.3.16 and jaxlib is 0.3.15, with CUDA 11.1 and cuDNN 8.0.5. Do you know how to fix this problem?
Thank you very much and looking forward to your reply.
Hi @HouJun19, it seems that your jax and jaxlib versions are too new. Could you take a look at this previous issue? https://github.com/YukunXia/Carla_iLQR_MPC/issues/5#issuecomment-992945428
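Besides downgrading, the traceback hints at what changed. The `cache_miss` frame in `api.py` shows that `derivative_init()` is itself wrapped in `jax.jit`. The error says one of its return values is a compiled `jacfwd` function, and jax 0.3.16 rejects that as "not a valid JAX type". A possible workaround, sketched here under assumptions, is to leave the factory un-jitted and apply `jax.jit` to each derivative function it returns. The cost and dynamics functions (`running_cost`, `dynamics`) are hypothetical stand-ins, not the repository's actual model, and only three of the five derivatives are shown.

```python
import jax
import jax.numpy as jnp


def running_cost(x, u):
    # Hypothetical quadratic stage cost (stand-in for the real iLQR cost).
    return jnp.sum(x ** 2) + 0.1 * jnp.sum(u ** 2)


def dynamics(x, u):
    # Hypothetical linear dynamics with a 0.1 s time step.
    return x + 0.1 * u


# No @jax.jit here: a jitted function may only return arrays (pytrees of
# arrays), not functions, so the factory itself stays plain Python.
def derivative_init():
    # Gradients (l_x, l_u) of the stage cost.
    jac_l = jax.jit(jax.jacfwd(running_cost, argnums=(0, 1)))
    # Hessian blocks ((l_xx, l_xu), (l_ux, l_uu)) of the stage cost.
    hes_l = jax.jit(jax.hessian(running_cost, argnums=(0, 1)))
    # Dynamics Jacobians (f_x, f_u).
    jac_f = jax.jit(jax.jacfwd(dynamics, argnums=(0, 1)))
    return jac_l, hes_l, jac_f


if __name__ == "__main__":
    jac_l, hes_l, jac_f = derivative_init()
    x, u = jnp.ones(4), jnp.ones(4)
    l_x, l_u = jac_l(x, u)
    f_x, f_u = jac_f(x, u)
    print(l_x, l_u)
    print(f_x.shape, f_u.shape)
```

The derivative functions are still compiled, so per-step speed should be similar. Only the outer factory runs as ordinary Python, once, at start-up.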
Thanks for your reply!
Actually, I have run this code successfully with jaxlib 0.1.47, but it is very slow: one loop takes more than 2 hours to complete, and there is a warning saying that no GPU was found and it is falling back to the CPU.
So I want to use the GPU to speed up the code, but I have tried many versions and all of them failed. My GPU is an NVIDIA 3060 Ti, which some older versions of jaxlib and CUDA do not support. When I try jaxlib 0.1.52, I get the same error: "cannot import name 'pytree' from 'jaxlib'". With newer jax and jaxlib versions I get other errors, like the one in my original question, and I really don't know how to solve them. Sorry if I have misunderstood something; I have only just started using Python and still lack experience.
Looking forward to your reply.
Many thanks.
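The CPU-fallback warning means JAX did not find a usable GPU backend. This often happens when a CPU-only jaxlib wheel is installed. A quick way to confirm which backend JAX actually picked up is the minimal sketch below. It uses `jax.devices()` and `jax.default_backend()`, which exist in recent JAX releases, though very old releases may lack `default_backend`. It only reports what JAX sees and does not fix the installation; `report_backend` is a hypothetical helper name.

```python
import jax


def report_backend():
    """Print the devices JAX can see and return the default backend name."""
    # With a working CUDA-enabled jaxlib this is typically "gpu";
    # "cpu" means JAX is running on the CPU only.
    backend = jax.default_backend()
    print(f"jax {jax.__version__}, default backend: {backend}")
    for device in jax.devices():
        print(f"  {device.platform}: {device}")
    return backend


if __name__ == "__main__":
    report_backend()
```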
I ran my code on an old laptop CPU (i7-7700HQ) two years ago, so your machine should be powerful enough. I could easily achieve a 10 Hz decision frequency, and the bottleneck was on the side of algorithm stability or Carla data quality, not computation speed. My dynamical model assumes a time interval of 0.1 s.
Maybe you could try some simpler JAX code on your CPU and check whether it runs at a reasonable speed.
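A simple speed check could look like the following sketch. The workload (`toy_step`, a small dense layer with `tanh`) is a hypothetical stand-in, not part of the project. The point is to time the first call separately, since that call includes tracing and XLA compilation, and then compare the steady-state per-call time against the 0.1 s step of the dynamical model. `block_until_ready()` is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x, w):
    # Hypothetical workload: one small dense layer followed by tanh.
    return jnp.tanh(w @ x)


def time_calls(n_calls=100, dim=64):
    key = jax.random.PRNGKey(0)
    w = jax.random.normal(key, (dim, dim))
    x = jnp.ones(dim)

    # First call: includes tracing and XLA compilation.
    t0 = time.perf_counter()
    toy_step(x, w).block_until_ready()
    first = time.perf_counter() - t0

    # Later calls reuse the compiled executable.
    t0 = time.perf_counter()
    for _ in range(n_calls):
        x = toy_step(x, w)
    x.block_until_ready()
    steady = (time.perf_counter() - t0) / n_calls
    return first, steady


if __name__ == "__main__":
    first, steady = time_calls()
    print(f"first call (with compile): {first:.4f} s")
    print(f"steady-state per call:     {steady:.6f} s")
    print(f"fits a 0.1 s (10 Hz) budget: {steady < 0.1}")
```

If the steady-state time is small but the real script still takes seconds per step, the slowdown is probably not raw CPU speed.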
As for the GPU, whether it speeds things up depends on the program. My program only uses a shallow neural network, so you might see some performance gain, but probably not much.
I should have provided a Docker environment to make testing easier, but right now I don't have enough time to fully revisit this project and upgrade the dependencies.
I didn't expect JAX to evolve this much :(
Thanks a lot! I will try some Jax code again.
But the first time I ran this with jaxlib, every step seemed to take more than 4 seconds, which is really weird.
I just run
./CarlaUE4.sh
and then
python ilqr_jax_MPC.py
I wonder if I have missed something. T^T