
`qml.sample()` fails with OutDBIdx shape canonicalization error in dynamic one-shot context

Open · rniczh opened this issue 4 months ago · 1 comment

Context

When using @qjit with finite shots, qml.sample() measurements fail with a shape canonicalization error involving OutDBIdx references, while qml.expval() measurements work correctly. The failure occurs in the dynamic_one_shot transformation, which is applied automatically when using finite shots, and when wires is not set on the device passed to qml.qnode.

Reproduction

❌ Failing Case (qml.sample())

import pennylane as qml
from catalyst import qjit

backend = "lightning.qubit"

@qjit
@qml.qnode(qml.device(backend, shots=10), mcm_method='one-shot')  # no wires= passed to the device
def circuit():
    qml.RX(0.0, wires=3)
    return qml.sample()  # TypeError is raised while @qjit traces this function

circuit()

✅ Working Case (qml.expval())

@qjit
@qml.qnode(qml.device(backend, shots=10), mcm_method='one-shot')
def circuit():
    qml.RX(0.0, wires=3)
    return qml.expval(qml.PauliZ(0))  # This works

circuit()

Why expval() works

  • No OutDBIdx references are created: the result shape in the expval(...) case is a static ShapedType, which jnp.zeros() can handle.
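To see why the static case succeeds and the dynamic one does not, here is a small plain-JAX sketch of the buffer-initialization step shown in the traceback (the jnp.zeros list comprehension in function_maps.py). `FakeOutDBIdx` and `init_results` are hypothetical stand-ins, not Catalyst or JAX APIs: the first is a non-integer placeholder that, like OutDBIdx(val=0) in the traceback, refers to another output holding a dimension size; the second mirrors the comprehension. jnp.zeros accepts a shape made only of integers but raises TypeError from shape canonicalization when a dimension is not an integer scalar. The exact message text varies across JAX versions.

```python
from typing import NamedTuple

import jax.numpy as jnp


class FakeOutDBIdx(NamedTuple):
    """Hypothetical stand-in for an OutDBIdx-style placeholder: it refers to
    another output that holds a dimension size, so it is not an integer."""

    val: int


def init_results(shapes_and_dtypes):
    """Mimic the buffer-initialization step: one zero array per result."""
    return [jnp.zeros(shape=shape, dtype=dtype) for shape, dtype in shapes_and_dtypes]


# expval-like case: every dimension is a plain integer, so allocation succeeds
static_ok = init_results([((1,), jnp.float32)])
print(static_ok[0].shape)  # (1,)

# sample()-like case: the second dimension is a placeholder, not an integer
try:
    init_results([((1, FakeOutDBIdx(val=0)), jnp.int32)])
except TypeError as err:
    print("TypeError:", err)
```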

Full Stack Trace

Traceback (most recent call last):
  File "/path/to/test.py", line 7, in <module>
    @qjit
     ^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 502, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 565, in __init__
    self.aot_compile()
    ~~~~~~~~~~~~~~~~^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 618, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
                                                              ~~~~~~~~~~~~^
        self.user_sig or ()
        ^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 759, in capture
    jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
                                        ~~~~~~~~~~~~~~^
        self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, dbg
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jax_tracer.py", line 613, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
                                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/catalyst/frontend/catalyst/jax_extras/tracing.py", line 499, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jit.py", line 749, in closure
    return QFunc.__call__(
           ~~~~~~~~~~~~~~^
        qnode,
        ^^^^^^
        *args,
        ^^^^^^
        **dict(params, **kwargs),
        ^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/qfunc.py", line 143, in __call__
    return Function(dynamic_one_shot(self, mcm_config=mcm_config))(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/jax_tracer.py", line 181, in __call__
    jaxpr, _, out_tree = make_jaxpr2(
                         ~~~~~~~~~~~~
        self.fn,
        ~~~~~~~~
        debug_info=kwargs.pop("debug_info", jdb("Function", self.fn, args, kwargs)),
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(*args, **kwargs)
    ~^^^^^^^^^^^^^^^^^
  File "/path/to/catalyst/frontend/catalyst/jax_extras/tracing.py", line 499, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
                              ~~~~~~~~~~~~~~~~~~~~~~~^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/profiler.py", line 354, in wrapper
    return func(*args, **kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/interpreters/partial_eval.py", line 2363, in trace_to_jaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 211, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/linear_util.py", line 396, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
  File "/path/to/catalyst/frontend/catalyst/qfunc.py", line 286, in one_shot_wrapper
    results = catalyst.vmap(wrap_single_shot_qnode)(arg_vmap)
  File "/path/to/catalyst/frontend/catalyst/api_extensions/function_maps.py", line 235, in __call__
    init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape, _ in shapes]
                        ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/numpy/array_creation.py", line 82, in zeros
    shape = canonicalize_shape(shape)
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/numpy/array_creation.py", line 45, in canonicalize_shape
    return core.canonicalize_shape(shape, context)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
  File "/path/to/venv/catalyst/lib/python3.13/site-packages/jax/_src/core.py", line 1864, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of integer scalars, got (1, OutDBIdx(val=0))

Posted by rniczh, Jul 30 '25 04:07

We could call eval_jaxpr to obtain Traced<i64[1,Traced<int64[]>with<DynamicJaxprTrace>]>with<DynamicJaxprTrace> values instead of OutDBIdx references:

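from jax.core import eval_jaxpr
from jax.tree_util import tree_flatten
from jax._src.util import unzip2

# Re-evaluate the traced jaxpr so dynamic dims come back as tracers, not OutDBIdx
fn_args_flat_for_eval = tree_flatten((fn_args, kwargs))[0]
res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *fn_args_flat_for_eval)
# Keep only the explicit results; drop implicit outputs flagged as not kept
_, out_keep = unzip2(shapes)
res_flat = [r for r, k in zip(res_expanded, out_keep) if k]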
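The evaluate-then-filter pattern can be sketched in plain JAX as follows. This is a static-shape illustration only and assumes nothing about Catalyst internals: `fn`, `x`, and the keep flags are hypothetical stand-ins for the traced function, its flat arguments, and the `(aval, keep)` pairs in `shapes`. It does not turn on JAX's dynamic-shape mode, so it shows the mechanics of re-running a closed jaxpr and filtering its outputs rather than reproducing the OutDBIdx case.

```python
import jax
import jax.numpy as jnp

try:
    from jax.core import eval_jaxpr  # public location in many JAX releases
except ImportError:
    from jax._src.core import eval_jaxpr  # private fallback for other releases


def fn(x):
    # Two outputs; the second plays the role of an output that is not kept
    return x * 2.0, jnp.sum(x)


x = jnp.arange(3.0)
closed_jaxpr = jax.make_jaxpr(fn)(x)

# Re-run the traced program on the flat arguments instead of reading abstract shapes
res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, x)

# Hypothetical (aval, keep) pairs standing in for `shapes`
shapes = list(zip(closed_jaxpr.out_avals, [True, False]))
_, out_keep = zip(*shapes)
res_flat = [r for r, k in zip(res_expanded, out_keep) if k]
print(res_flat[0])  # [0. 2. 4.]
```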

Since the last dimension is dynamic, the original loop runs into an error. We could use lax.scan to turn it into a while loop and avoid that. However, this results in a shape we cannot handle, Traced<i64[6,1,Var(id=4592326400):int64[]]>with<DynamicJaxprTrace> (6 shots). We expect the last dimension to be a Traced value rather than a Var, since every result generated by fn should have the same shape.

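import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def scan_fn(carry, i):
    # Copy so the shared flat argument list is not mutated across iterations
    fn_args_flat = list(args_flat)
    for loc in batch_loc:
        ax = in_axes_flat[loc]
        # Select the i-th slice of each batched argument along its mapped axis
        fn_args_flat[loc] = jnp.take(args_flat[loc], i, axis=ax)

    fn_args = tree_unflatten(args_tree, fn_args_flat)
    res = self.fn(*fn_args, **kwargs)
    res_flat, _ = tree_flatten(res)
    # lax.scan expects (new_carry, per-iteration output)
    return carry, res_flat

_, all_results = jax.lax.scan(scan_fn, None, jnp.arange(batch_size))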
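A minimal static-shape sketch of the lax.scan rewrite in plain JAX, assuming a hypothetical `per_shot` function in place of `self.fn` and a single batched argument mapped along axis 0. lax.scan stacks each iteration's output along a new leading axis of length batch_size, which is where the leading 6 in the shape above comes from; with static per-shot shapes, every dimension after it stays a plain integer.

```python
import jax
import jax.numpy as jnp


def per_shot(x):
    # Hypothetical per-shot function: returns a (1, 3) result for each slice
    return jnp.reshape(x, (1, 3)) * 2


def batched_with_scan(args, batch_size):
    def scan_fn(carry, i):
        x_i = jnp.take(args, i, axis=0)  # i-th slice along the batch axis
        return carry, per_shot(x_i)  # carry is unused (None)

    # Outputs of every iteration are stacked along a new leading axis
    _, all_results = jax.lax.scan(scan_fn, None, jnp.arange(batch_size))
    return all_results


out = jax.jit(batched_with_scan, static_argnums=1)(jnp.arange(18.0).reshape(6, 3), 6)
print(out.shape)  # (6, 1, 3)
```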

Posted by rniczh, Aug 07 '25 13:08