
Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)>

Dharmendra04 opened this issue 1 year ago · 1 comment

Description

Traced<ShapedArray(float32[])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7f5f9828e320>, in_tracers=(Traced<ShapedArray(float32[64,192,20]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[64,192,20]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[]):JaxprTrace(level=1/1)>), out_tracer_refs=[<weakref at 0x7f5f981186d0; to 'JaxprTracer' at 0x7f5f981188b0>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[64,192,20] b:f32[64,192,20] c:f32[]. let
    d:f32[64,192,20] = mul a b
    e:f32[] = reduce_sum[axes=(0, 1, 2)] d
    f:f32[] = div e c
  in (f,) }, 'in_shardings': (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue), 'out_shardings': (UnspecifiedValue,), 'resource_env': None, 'donated_invars': (False, False, False), 'name': '_reduce_max', 'keep_unused': False, 'inline': True}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f5f9811f170>, name_stack=NameStack(stack=(Transform(name='jvp'), Scope(name='NerfModel'))))) max_reflection_encoding

When I print the shape of a JAX array, it prints as expected for the first batch and then shows this message. What exactly is this message saying?

System info (python version, jaxlib version, accelerator, etc.)

python = 3.11, jaxlib version = 0.4.23, accelerator = GPU

jax:    0.4.23
jaxlib: 0.4.23
numpy:  1.26.3
python: 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1

$ nvidia-smi
Tue Feb 20 10:31:40 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:01:00.0  On |                  Off |
| 37%   57C    P2             287W / 450W |  23126MiB / 24564MiB |      92%     Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory |
|        ID   ID                                                              Usage      |
|=======================================================================================|
|    0   N/A  N/A      2812      G   /usr/lib/xorg/Xorg                           337MiB |
|    0   N/A  N/A      2980      G   /usr/bin/gnome-shell                          54MiB |
|    0   N/A  N/A      3619      G   ...irefox/3728/usr/lib/firefox/firefox        30MiB |
|    0   N/A  N/A    237301      G   colmap                                        24MiB |
|    0   N/A  N/A    553109      G   /opt/teamviewer/tv_bin/TeamViewer             17MiB |
|    0   N/A  N/A    558062      C   python                                     22216MiB |
|    0   N/A  N/A    560150      C   python                                       386MiB |
+---------------------------------------------------------------------------------------+

Dharmendra04 · Feb 20 '24 10:02

Hi - thanks for the question, and sorry for the unclear message, but I think we'll need more information to help you. What you printed above looks like the normal repr of a traced object within an autodiff transformation; for example:

In [1]: import jax

In [2]: def f(x):
   ...:     print(x)  # this will print the tracer value within a grad transformation
   ...:     return jax.numpy.sin(x)
   ...: 

In [3]: jax.grad(f)(1.0)
Traced<ConcreteArray(1.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
  primal = 1.0
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[], weak_type=True), None)
    recipe = LambdaBinding()
Out[3]: Array(0.5403023, dtype=float32, weak_type=True)
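
As an aside, if what you want is the runtime value of x rather than its tracer repr, one option is jax.debug.print, which prints the actual value at runtime even under transformations like grad and jit. Continuing the session above, a minimal sketch:

In [4]: def g(x):
   ...:     jax.debug.print("x = {}", x)  # prints the runtime value, not the tracer repr
   ...:     return jax.numpy.sin(x)
   ...: 

In [5]: jax.grad(g)(1.0)
x = 1.0
Out[5]: Array(0.5403023, dtype=float32, weak_type=True)

Here the printed x = 1.0 is the primal value seen at runtime, whereas the plain print in f above shows what Python sees at trace time.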

Can you paste the code you were running (a minimal reproducible example, if possible) and the full error traceback, if applicable? Also, it helps to put code and tracebacks between triple backticks (```) to format them as code. Thanks!

jakevdp · Feb 20 '24 13:02

Closing due to lack of activity here – feel free to comment here or open another issue if you're still running into problems!

jakevdp · Apr 10 '24 19:04