jax
Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)>
Description
```
Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f5f9828e320>, in_tracers=(Traced<ShapedArray(float32[64,192,20]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[64,192,20]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>), out_tracer_refs=[<weakref at 0x7f5f981186d0; to 'JaxprTracer' at 0x7f5f981188b0>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[64,192,20] b:f32[64,192,20] c:f32[]. let
    d:f32[64,192,20] = mul a b
    e:f32[] = reduce_sum[axes=(0, 1, 2)] d
    f:f32[] = div e c
  in (f,) }, 'in_shardings': (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue), 'out_shardings': (UnspecifiedValue,), 'resource_env': None, 'donated_invars': (False, False, False), 'name': '_reduce_max', 'keep_unused': False, 'inline': True}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f5f9811f170>, name_stack=NameStack(stack=(Transform(name='jvp'), Scope(name='NerfModel'))))) max_reflection_encoding
```
When I print the shape of a JAX array, it prints correctly for the first batch and then shows this message. What exactly is this message saying?
System info (python version, jaxlib version, accelerator, etc.)
python = 3.11, jaxlib version = 0.4.23, accelerator = GPU
jax: 0.4.23
jaxlib: 0.4.23
numpy: 1.26.3
python: 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
$ nvidia-smi
Tue Feb 20 10:31:40 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 4090 On | 00000000:01:00.0 On | Off |
| 37% 57C P2 287W / 450W | 23126MiB / 24564MiB | 92% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 2812 G /usr/lib/xorg/Xorg 337MiB |
| 0 N/A N/A 2980 G /usr/bin/gnome-shell 54MiB |
| 0 N/A N/A 3619 G ...irefox/3728/usr/lib/firefox/firefox 30MiB |
| 0 N/A N/A 237301 G colmap 24MiB |
| 0 N/A N/A 553109 G /opt/teamviewer/tv_bin/TeamViewer 17MiB |
| 0 N/A N/A 558062 C python 22216MiB |
| 0 N/A N/A 560150 C python 386MiB |
+---------------------------------------------------------------------------------------+
Hi - thanks for the question, and sorry for the unclear error message, but I think we'll need more information in order to help you. What you printed above looks like the normal repr of a traced object within an autodiff transformation; for example:
```
In [1]: import jax

In [2]: def f(x):
   ...:     print(x)  # this will print the tracer value within a grad transformation
   ...:     return jax.numpy.sin(x)
   ...:

In [3]: jax.grad(f)(1.0)
Traced<ConcreteArray(1.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
  primal = 1.0
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[], weak_type=True), None)
    recipe = LambdaBinding()
Out[3]: Array(0.5403023, dtype=float32, weak_type=True)
```
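To make this concrete, here is a minimal sketch of what printing does inside transformed code. The function `loss` and the `(2, 3)` input are hypothetical, not taken from your report. A tracer's `.shape` and `.dtype` are static, so a plain Python `print(x.shape)` shows concrete values; printing the tracer itself shows its repr; `jax.debug.print` prints the runtime value. Wrapping `jax.grad` in `jax.jit` nests an autodiff trace inside a staging trace, which is roughly the combination in your output (a `JVPTrace` tracer whose primal is a `DynamicJaxprTrace` tracer). Exact repr text and trace levels vary across JAX versions.

```python
import jax
import jax.numpy as jnp


def loss(x):  # hypothetical example function
    # Shape and dtype are static even on a tracer, so these print concrete values.
    print("shape:", x.shape, "dtype:", x.dtype)
    # Printing the tracer itself shows its repr (Traced<ShapedArray(...)>with<JVPTrace(...)> ...).
    print("tracer:", x)
    # jax.debug.print is staged into the computation and prints the runtime value.
    jax.debug.print("value: {}", x)
    return jnp.sum(jnp.sin(x))


# jit around grad nests a JVP trace inside a jaxpr-staging trace.
grad_fn = jax.jit(jax.grad(loss))

x = jnp.ones((2, 3), dtype=jnp.float32)
g = grad_fn(x)  # first call: the function is traced, so the Python prints run
g = grad_fn(x)  # same shapes: the compiled function is reused; only jax.debug.print output appears
```

This may also be why you only saw output for the first batch: under `jax.jit`, Python `print` runs only while tracing, and later calls with the same input shapes and dtypes reuse the compiled function without re-running the Python code.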
Can you paste the code you were running (a minimal reproducible example if possible) and the full error traceback if applicable? Also, it helps to put code and tracebacks between triple tick marks (```) to format them as code. Thanks!
Closing due to lack of activity here – feel free to comment here or open another issue if you're still running into problems!