catalyst

Compilation error for OQC device

Open devx5f opened this issue 2 months ago • 17 comments

Description: Using the OQC device returns a compile error.

Code Snippet:

import pennylane as qml
from catalyst import qjit

dev = qml.device("oqc.cloud", backend="lucy", shots=2012, wires=1)

@qjit
@qml.qnode(dev)
def circuit(x: float):
    qml.RX(x, wires=0)
    return qml.counts(wires=[0])

Error:

CompileError: Device at /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/utils/../lib/librtd_oqc.so cannot be found!
Stack Trace
---------------------------------------------------------------------------
CompileError                              Traceback (most recent call last)
Cell In[5], line 3
      1 dev = qml.device("oqc.cloud", backend="lucy", shots=2012, wires=1)
----> 3 @qjit
      4 @qml.qnode(dev)
      5 def circuit(x: float):
      6     qml.RX(x, wires=0)
      7     return qml.counts(wires=[0, 1])

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:502, in qjit(fn, autograph, autograph_include, async_qnodes, target, keep_intermediate, verbose, logfile, pipelines, static_argnums, static_argnames, abstracted_axes, disable_assertions, seed, circuit_transform_pipeline, pass_plugins, dialect_plugins)
    499 if fn is None:
    500     return functools.partial(qjit, **kwargs)
--> 502 return QJIT(fn, CompileOptions(**kwargs))

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:565, in QJIT.__init__(self, fn, compile_options)
    563 # Static arguments require values, so we cannot AOT compile.
    564 if self.user_sig is not None and not self.compile_options.static_argnums:
--> 565     self.aot_compile()
    567 super().__init__("user_function")

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:618, in QJIT.aot_compile(self)
    616 # TODO: awkward, refactor or redesign the target feature
    617 if self.compile_options.target in ("jaxpr", "mlir", "binary"):
--> 618     self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
    619         self.user_sig or ()
    620     )
    622 if self.compile_options.target in ("mlir", "binary"):
    623     self.mlir_module = self.generate_ir()

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/debug/instruments.py:145, in instrument.<locals>.wrapper(*args, **kwargs)
    142 @functools.wraps(fn)
    143 def wrapper(*args, **kwargs):
    144     if not InstrumentSession.active:
--> 145         return fn(*args, **kwargs)
    147     with ResultReporter(stage_name, has_finegrained) as reporter:
    148         self = args[0]

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:759, in QJIT.capture(self, args, **kwargs)
    749     return QFunc.__call__(
    750         qnode,
    751         *args,
    752         **dict(params, **kwargs),
    753     )
    755 with Patcher(
    756     (qml.QNode, "__call__", closure),
    757 ):
    758     # TODO: improve PyTree handling
--> 759     jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
    760         self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, dbg
    761     )
    762     self.compile_options.pass_plugins.update(plugins)
    763     self.compile_options.dialect_plugins.update(plugins)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jax_tracer.py:609, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_info)
    603     make_jaxpr_kwargs = {
    604         "static_argnums": static_argnums,
    605         "abstracted_axes": abstracted_axes,
    606         "debug_info": debug_info,
    607     }
    608     with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
--> 609         jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
    610         plugins = EvaluationContext.get_plugins()
    612 return jaxpr, out_type, out_treedef, plugins

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py:491, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
    489     f, out_tree_promise = flatten_fun(f, in_tree)
    490     f = annotate(f, in_type)
--> 491     jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
    492 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
    493 return closed_jaxpr, out_type, out_tree_promise()

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/profiler.py:334, in annotate_function.<locals>.wrapper(*args, **kwargs)
    331 @wraps(func)
    332 def wrapper(*args, **kwargs):
    333   with TraceAnnotation(name, **decorator_kwargs):
--> 334     return func(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2276, in trace_to_jaxpr_dynamic2(fun)
   2274 in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
   2275 with core.set_current_trace(trace):
-> 2276   ans = fun.call_wrapped(*in_tracers)
   2277 out_tracers = map(trace.to_jaxpr_tracer, ans)
   2278 jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/linear_util.py:211, in WrappedFun.call_wrapped(self, *args, **kwargs)
    209 def call_wrapped(self, *args, **kwargs):
    210   """Calls the transformed function"""
--> 211   return self.f_transformed(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/api_util.py:73, in flatten_fun(f, store, in_tree, *args_flat)
     69 @lu.transformation_with_aux2
     70 def flatten_fun(f: Callable, store: lu.Store,
     71                 in_tree: PyTreeDef, *args_flat):
     72   py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 73   ans = f(*py_args, **py_kwargs)
     74   ans, out_tree = tree_flatten(ans)
     75   store.store(out_tree)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/linear_util.py:402, in _get_result_paths_thunk(_fun, _store, *args, **kwargs)
    400 @transformation_with_aux2
    401 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
--> 402   ans = _fun(*args, **kwargs)
    403   result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans))
    404   if _store:
    405     # In some instances a lu.WrappedFun is called multiple times, e.g.,
    406     # the bwd function in a custom_vjp

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:749, in QJIT.capture.<locals>.closure(qnode, *args, **kwargs)
    746 params["pass_pipeline"] = pass_pipeline
    747 params["debug_info"] = dbg
--> 749 return QFunc.__call__(
    750     qnode,
    751     *args,
    752     **dict(params, **kwargs),
    753 )

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/qfunc.py:145, in QFunc.__call__(self, *args, **kwargs)
    143 new_device = copy(self.device)
    144 new_device._shots = self._shots  # pylint: disable=protected-access
--> 145 qjit_device = QJITDevice(new_device)
    147 static_argnums = kwargs.pop("static_argnums", ())
    148 out_tree_expected = kwargs.pop("_out_tree_expected", [])

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/device/qjit_device.py:326, in QJITDevice.__init__(self, original_device)
    321     if _to_matrix_ops and not device_capabilities.supports_operation("QubitUnitary"):
    322         raise CompileError(
    323             "The device that specifies to_matrix_ops must support QubitUnitary."
    324         )
--> 326 backend = QJITDevice.extract_backend_info(original_device, device_capabilities)
    328 self.backend_name = backend.c_interface_name
    329 self.backend_lib = backend.lpath

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/device/qjit_device.py:300, in QJITDevice.extract_backend_info(device, capabilities)
    296 @staticmethod
    297 @debug_logger
    298 def extract_backend_info(device, capabilities: DeviceCapabilities) -> BackendInfo:
    299     """Wrapper around extract_backend_info in the runtime module."""
--> 300     return extract_backend_info(device, capabilities)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/device/qjit_device.py:175, in extract_backend_info(device, capabilities)
    172     raise CompileError(f"The {dname} device does not provide C interface for compilation.")
    174 if not pathlib.Path(device_lpath).is_file():
--> 175     raise CompileError(f"Device at {device_lpath} cannot be found!")
    177 if dname == "braket.local.qubit":  # pragma: no cover
    178     device_kwargs["device_type"] = dname

Version: 0.12.0

devx5f, Sep 03 '25 19:09

Hi @devx5f, thanks for reporting this. I tried to replicate this locally but couldn't. I also checked the published wheels manually, and they do contain the relevant library.

I wonder if this has something to do with Anaconda, since I see it in the path of the error message. There are two things you could try that would help us:

  • Try your example in a standard-library venv environment (everything installed with pip). Does the same problem occur?
  • Search your installation directory (/opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst) for "librtd_oqc" files. Is the shared library present at all? Maybe it's in an unexpected location.
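For the second check, a small Python helper like the one below lists every file whose name starts with `librtd_oqc` under a directory. It is a sketch, not part of Catalyst: the name `find_runtime_libs` is hypothetical, and it only inspects the file system.

```python
import pathlib


def find_runtime_libs(root, prefix="librtd_oqc"):
    """Return sorted paths of files under `root` whose names start with `prefix`.

    Hypothetical helper for locating the OQC runtime library (e.g. .so on Linux,
    .dylib on macOS) anywhere inside an installed package directory.
    """
    root = pathlib.Path(root)
    # rglob walks all subdirectories, so a library in an unexpected place is still found
    return sorted(p for p in root.rglob(f"{prefix}*") if p.is_file())
```

Assuming Catalyst is importable in the same environment, it can be pointed at the installed package with `find_runtime_libs(pathlib.Path(catalyst.__file__).parent)`.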

dime10, Sep 03 '25 20:09

Hi, thank you for looking into this.

I tried it in a venv environment and the error still occurs there.

There's a libcustom_calls.so file in the utils directory, and a librtd_oqc.dylib and an oqc_python_module.so in catalyst/lib/, but I don't see a librtd_oqc.so file anywhere in the catalyst package in the conda environment. I tried it with Python 3.12 and 3.11, and my conda version is 24.9.2, if that helps.

devx5f, Sep 03 '25 23:09

Thanks @devx5f, that's really helpful. Could you confirm which OS (Linux or macOS) you are on, and which architecture (x86 or ARM)?

dime10, Sep 04 '25 19:09

I'm on a Mac with an ARM M4 Max chip.

devx5f, Sep 04 '25 19:09

Hi @devx5f, thanks for reporting this!

If librtd_oqc.dylib exists in your catalyst/lib/ (which is the expected layout on macOS), then this is likely a bug on our end in the logic that decides whether the library suffix is .so or .dylib.

After some searching, there it is: https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/third_party/oqc/oqc_device.py#L48

We will fix this soon. Thanks again for catching this 👍

paul0403, Sep 04 '25 19:09

okay great! Appreciate you looking into it

devx5f, Sep 04 '25 19:09

To confirm that this is actually the issue, can you change the line linked above on your machine to use the .dylib suffix and see if it works?

It would be in

your_venv/lib/python3.10/site-packages/catalyst/third_party/oqc/oqc_device.py

line 48 in the current version

-        return "oqc", get_lib_path("oqc_runtime", "OQC_LIB_DIR") + "/librtd_oqc.so"
+        return "oqc", get_lib_path("oqc_runtime", "OQC_LIB_DIR") + "/librtd_oqc.dylib"
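The local patch above hard-codes .dylib, which only works on macOS. A platform-aware path could instead be chosen at runtime, as sketched below. The helper names are hypothetical, this is not necessarily how the upstream fix was implemented, and it only distinguishes macOS from everything else (Windows is ignored).

```python
import platform


def shared_lib_suffix(system=None):
    """Return the shared-library suffix for an OS name.

    `system` defaults to platform.system(), which reports "Darwin" on macOS
    and "Linux" on Linux.
    """
    system = system or platform.system()
    return ".dylib" if system == "Darwin" else ".so"


def oqc_runtime_lib(lib_dir, system=None):
    """Build the path to the OQC runtime library (hypothetical helper)."""
    return f"{lib_dir}/librtd_oqc{shared_lib_suffix(system)}"
```

Passing `system` explicitly keeps the helper testable on any machine; in normal use it would be left as `None`.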

paul0403, Sep 04 '25 19:09

Yep, that fixed it.

However, when I ran the circuit, I got the following error:

RuntimeError: [/Users/runner/work/catalyst/catalyst/frontend/catalyst/third_party/oqc/src/../../../../../runtime/lib/backend/common/DynamicLibraryLoader.hpp:75][Function:getSymbol] Error in Catalyst Runtime: dlsym(0xebb42b30, counts): symbol not found

devx5f, Sep 04 '25 20:09

Thanks for reporting! There's also a separate bug in the signature of the counts function; we will fix this soon!
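The RuntimeError above comes from `dlsym`, which looks up an exported symbol by name in an already-loaded shared library. To check by hand whether a library exports a given name, Python's `ctypes` can be used as sketched below. `exports_symbol` is a hypothetical helper, and it checks the plain (unmangled) name, just as `dlsym` does.

```python
import ctypes


def exports_symbol(lib_path, symbol):
    """Return True if the shared library at `lib_path` exports `symbol`.

    On POSIX systems ctypes resolves attribute access through dlsym, so a
    missing symbol raises AttributeError, which hasattr turns into False.
    """
    lib = ctypes.CDLL(lib_path)
    return hasattr(lib, symbol)
```

For example, it could be pointed at the librtd_oqc.dylib path found earlier with the symbol name `"counts"` from the error message.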

paul0403, Sep 04 '25 20:09

great, thanks again for looking into it

devx5f, Sep 04 '25 21:09

Hi guys, thank you for addressing the issues so promptly 🔥 I tested the fixes on the dev version of the package, and it looks like the issues above are resolved, so thank you for those!


~~When running the circuit, it looks like the circuit is hanging indefinitely though with no error. I tried this with valid creds and invalid creds and got the same behavior with it hanging. We're able to run a circuit with the oqc-qcaas-client so can verify the creds are valid.~~

When I tested in a venv environment, it returned an error instead:

RuntimeError: [/Users/runner/work/catalyst/catalyst/frontend/catalyst/third_party/oqc/src/./OQCRunner.hpp:48][Function:Counts] Error in Catalyst Runtime: No module named 'qcaas_client'

Code:

import pennylane as qml
from catalyst import qjit

dev = qml.device("oqc.cloud", backend="lucy", shots=2012, wires=1)

@qjit
@qml.qnode(dev)
def circuit(x: float):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(wires=1))

print(circuit(.77))
Stack trace

RuntimeError                              Traceback (most recent call last)
Cell In[9], line 1
----> 1 circuit(.77)

File ~/CS/test/test_bed_311/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/jit.py:602, in QJIT.__call__(self, *args, **kwargs)
    599     dynamic_args = filter_static_args(args, self.compile_options.static_argnums)
    600     args = promote_arguments(self.c_sig, dynamic_args)
--> 602 return self.run(args, kwargs)

File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/debug/instruments.py:145, in instrument.<locals>.wrapper(*args, **kwargs)
    142 @functools.wraps(fn)
    143 def wrapper(*args, **kwargs):
    144     if not InstrumentSession.active:
--> 145         return fn(*args, **kwargs)
    147     with ResultReporter(stage_name, has_finegrained) as reporter:
    148         self = args[0]

File ~/CS/test/test_bed_311/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/jit.py:835, in QJIT.run(self, args, kwargs)
    822 @instrument(has_finegrained=True)
    823 @debug_logger
    824 def run(self, args, kwargs):
    825     """Invoke a previously compiled function with the supplied arguments.
    826
    827     Args:
   (...)
    832         Any: results of the execution arranged into the original function's output PyTrees
    833     """
--> 835     results = self.compiled_function(*args, **kwargs)
    837     # TODO: Move this to the compiled function object.
    838     return tree_unflatten(self.out_treedef, results)

File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/compiled_functions.py:341, in CompiledFunction.__call__(self, *args, **kwargs)
    337 abi_args, _buffer = self.args_to_memref_descs(self.restype, dynamic_args, **kwargs)
    339 numpy_dict = {nparr.ctypes.data: nparr for nparr in _buffer}
--> 341 result = CompiledFunction._exec(
    342     self.shared_object,
    343     self.restype,
    344     self.out_type,
    345     numpy_dict,
    346     *abi_args,
    347 )
    349 return result

File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/compiled_functions.py:162, in CompiledFunction._exec(shared_object, has_return, out_type, numpy_dict, *args)
    160 with shared_object as lib:
    161     result_desc = type(args[0].contents) if has_return else None
--> 162     retval = wrapper.wrap(lib.function, args, result_desc, lib.mem_transfer, numpy_dict)
    164     if out_type is not None:
    165         keep_outputs = [k for _, k in out_type]

RuntimeError: [/Users/runner/work/catalyst/catalyst/frontend/catalyst/third_party/oqc/src/./OQCRunner.hpp:48][Function:Counts] Error in Catalyst Runtime: No module named 'qcaas_client'
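This error means that the Python interpreter the OQC runtime ran under could not import `qcaas_client`. When an environment mismatch like this is suspected, a quick check is to ask each interpreter (conda, venv, Jupyter kernel) whether it can locate the module. The helpers below are a hypothetical sketch built on the standard library's `importlib.util.find_spec`; they are intended for top-level module names only.

```python
import importlib.util
import sys


def module_available(name):
    """Return True if top-level module `name` can be found by this interpreter."""
    return importlib.util.find_spec(name) is not None


def report(name):
    """Print which interpreter is running and whether it can see `name`."""
    print(f"{sys.executable}: {name} importable = {module_available(name)}")
```

Running `report("qcaas_client")` from each environment shows which interpreter is missing the package.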

devx5f, Sep 09 '25 13:09

Hi guys, I was able to get around the "No module named" error by reverting to the nanobind package. However, when I tried to train a circuit, I got the following DifferentiableCompileError:

Code:

import numpy as np
import pennylane as qml
import catalyst

num_qubits = 8
num_layers = 2

weights = np.random.random(size=(1,))
dev = qml.device("oqc.cloud", backend="lucy", wires=num_qubits)

@qml.qjit()
@qml.qnode(dev, diff_method='parameter-shift')
def circuit(weights):
    qml.RY(weights, wires=0)
    return qml.expval(qml.PauliZ(0))

loss, grads = catalyst.value_and_grad(circuit)(weights)

Error:

DifferentiableCompileError: The parameter-shift method can only be used for QNodes which return either qml.expval or qml.probs.
Stack Trace:
DifferentiableCompileError                Traceback (most recent call last)
Cell In[4], line 1
----> 1 loss, grads = catalyst.value_and_grad(circuit)(weights)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:704, in GradCallable.__call__(self, *args, **kwargs)
    702 if self.grad_params.scalar_out:
    703     if self.grad_params.with_value:
--> 704         results = jax.value_and_grad(self.fn, argnums=argnums)(*args, **kwargs)
    705     else:
    706         results = jax.grad(self.fn, argnums=argnums)(*args, **kwargs)

    [... skipping hidden 16 frame]

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:596, in QJIT.__call__(self, *args, **kwargs)
    594     if self.jaxed_function is None:
    595         self.jaxed_function = JAX_QJIT(self)  # lazy gradient compilation
--> 596     return self.jaxed_function(*args, **kwargs)
    598 elif requires_promotion:
    599     dynamic_args = filter_static_args(args, self.compile_options.static_argnums)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:969, in JAX_QJIT.__call__(self, *args, **kwargs)
    967 @debug_logger
    968 def __call__(self, *args, **kwargs):
--> 969     return self.jaxed_function(*args, **kwargs)

    [... skipping hidden 9 frame]

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:944, in JAX_QJIT.compute_jvp(self, primals, tangents)
    942 results = self.wrap_callback(self.qjit_function, *primals)
    943 results_data, _results_shape = tree_flatten(results)
--> 944 derivatives = self.wrap_callback(self.get_derivative_qjit(argnums), *primals)
    945 derivatives_data, _derivatives_shape = tree_flatten(derivatives)
    947 jvps = [jnp.zeros_like(results_data[res_idx]) for res_idx in range(len(results_data))]

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:923, in JAX_QJIT.get_derivative_qjit(self, argnums)
    920 deriv_wrapper.__annotations__ = annotations
    921 deriv_wrapper.__signature__ = signature.replace(parameters=updated_params)
--> 923 self.derivative_functions[argnum_key] = QJIT(
    924     deriv_wrapper, self.qjit_function.compile_options
    925 )
    926 return self.derivative_functions[argnum_key]

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:559, in QJIT.__init__(self, fn, compile_options)
    557 # Static arguments require values, so we cannot AOT compile.
    558 if self.user_sig is not None and not self.compile_options.static_argnums:
--> 559     self.aot_compile()
    561 super().__init__("user_function")

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:612, in QJIT.aot_compile(self)
    610 # TODO: awkward, refactor or redesign the target feature
    611 if self.compile_options.target in ("jaxpr", "mlir", "binary"):
--> 612     self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
    613         self.user_sig or ()
    614     )
    616 if self.compile_options.target in ("mlir", "binary"):
    617     self.mlir_module = self.generate_ir()

File ~/CS/test/test_bed/catalyst/frontend/catalyst/debug/instruments.py:145, in instrument.<locals>.wrapper(*args, **kwargs)
    142 @functools.wraps(fn)
    143 def wrapper(*args, **kwargs):
    144     if not InstrumentSession.active:
--> 145         return fn(*args, **kwargs)
    147     with ResultReporter(stage_name, has_finegrained) as reporter:
    148         self = args[0]

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:759, in QJIT.capture(self, args, **kwargs)
    749     return QFunc.__call__(
    750         qnode,
    751         *args,
    752         **dict(params, **kwargs),
    753     )
    755 with Patcher(
    756     (qml.QNode, "__call__", closure),
    757 ):
    758     # TODO: improve PyTree handling
--> 759     jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
    760         self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, dbg
    761     )
    762     self.compile_options.pass_plugins.update(plugins)
    763     self.compile_options.dialect_plugins.update(plugins)

File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jax_tracer.py:613, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_info)
    607     make_jaxpr_kwargs = {
    608         "static_argnums": static_argnums,
    609         "abstracted_axes": abstracted_axes,
    610         "debug_info": debug_info,
    611     }
    612     with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
--> 613         jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
    614         plugins = EvaluationContext.get_plugins()
    616 return jaxpr, out_type, out_treedef, plugins

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jax_extras/tracing.py:483, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
    481     f, out_tree_promise = flatten_fun(f, in_tree)
    482     f = annotate(f, in_type)
--> 483     jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
    484 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
    485 return closed_jaxpr, out_type, out_tree_promise()

    [... skipping hidden 5 frame]

File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:917, in JAX_QJIT.get_derivative_qjit.<locals>.deriv_wrapper(*args, **kwargs)
    916 def deriv_wrapper(*args, **kwargs):
--> 917     return catalyst.jacobian(self.qjit_function, argnums=argnums)(*args, **kwargs)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:664, in GradCallable.__call__(self, *args, **kwargs)
    654 grad_params = _check_grad_params(
    655     self.grad_params.method,
    656     self.grad_params.scalar_out,
   (...)    661     self.grad_params.with_value,
    662 )
    663 input_data_flat, _ = tree_flatten((args, kwargs))
--> 664 jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *args, **kwargs)
    665 if self.grad_params.with_value:  # use value_and_grad
    666     args_argnum = tuple(args[i] for i in grad_params.argnums)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:835, in _make_jaxpr_check_differentiable(f, grad_params, *args, **kwargs)
    829     if res.dtype.kind != "f":
    830         raise DifferentiableCompileError(
    831             "Catalyst.grad/jacobian only supports differentiation on floating-point "
    832             f"results, got '{res.dtype}' at position {pos}."
    833         )
--> 835 _verify_differentiable_child_qnodes(jaxpr, method)
    836 return jaxpr, out_tree

File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:862, in _verify_differentiable_child_qnodes(jaxpr, method)
    859             traverse_children(child_jaxpr)
    860             visited.add(py_callable)
--> 862 traverse_children(jaxpr)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:858, in _verify_differentiable_child_qnodes.<locals>.traverse_children(jaxpr)
    856 if py_callable not in visited:
    857     if isinstance(py_callable, QNode):
--> 858         _check_qnode_against_grad_method(py_callable, method, child_jaxpr)
    859     traverse_children(child_jaxpr)
    860     visited.add(py_callable)

File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:895, in _check_qnode_against_grad_method(f, method, jaxpr)
    887     raise DifferentiableCompileError(
    888         "Cannot differentiate a QNode explicitly marked non-differentiable (with "
    889         "diff_method=None)."
    890     )
    892 if f.diff_method == "parameter-shift" and any(
    893     prim not in [expval_p, probs_p] for prim in return_ops
    894 ):
--> 895     raise DifferentiableCompileError(
    896         "The parameter-shift method can only be used for QNodes "
    897         "which return either qml.expval or qml.probs."
    898     )
    900 if f.diff_method == "adjoint" and any(prim not in [expval_p] for prim in return_ops):
    901     raise DifferentiableCompileError(
    902         "The adjoint method can only be used for QNodes which return qml.expval."
    903     )

DifferentiableCompileError: The parameter-shift method can only be used for QNodes which return either qml.expval or qml.probs.

It works fine with the lightning.qubit device. I noticed that when using the OQC device, the return_ops in the _check_qnode_against_grad_method method contain a dot_general instead of the expval.

devx5f, Sep 14 '25 21:09

Hi @devx5f, sorry for the late reply. Thanks for reporting the gradient error and providing the stack trace. We will take a look soon.

paul0403, Sep 25 '25 00:09

Hi @devx5f, we found that there are essentially two issues:

  1. Various small connection issues with the plugin. These are being fixed in #2089, which is scheduled to go into the next release (happening next week).
  • For the indefinite hanging: we discovered that the OQC cloud hangs when a QPU ID is not provided. There was a bug in our connection code; this is now fixed.
  • I could not reproduce the module-not-found error, though. I suspect it's related to the conda environment, but I'm not sure. Regardless, I have updated the OQC runtime device's pybind layer to reuse the user's top-level Python interpreter instead of starting a new one. Hopefully this resolves the package path issues 🙏
  2. Differentiability error. The parameter-shift method only works on circuits that return expectation values (in Catalyst, qml.expval or qml.probs); this is a property of the method itself. However, the OQC device only returns raw counts. So when a user circuit returns an expectation value, Catalyst internally transforms the expval call into a raw counts call plus post-processing math (the dot product op you're seeing). As a result, the differentiation check sees raw counts in the circuit and rejects it.
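The counts-to-expval post-processing described above can be sketched in plain NumPy. This is a minimal illustration, not Catalyst's actual implementation: it assumes a single-wire Z observable and bitstring keys with wire 0 as the leftmost character, and expval_z_from_counts is a hypothetical helper. The final dot product between eigenvalues and relative frequencies is the kind of operation that appears as dot_general in the traced program.

```python
import numpy as np


def expval_z_from_counts(counts: dict[str, int], wire: int = 0) -> float:
    """Estimate <Z> on one wire from bitstring counts (hypothetical helper).

    Assumes wire 0 is the leftmost character of each bitstring key.
    Each bitstring contributes the Z eigenvalue of the selected bit
    (+1 for '0', -1 for '1'), weighted by its relative frequency.
    """
    bitstrings = list(counts)
    # Z eigenvalue of the selected bit for each observed bitstring
    eigvals = np.array([1.0 if b[wire] == "0" else -1.0 for b in bitstrings])
    # Relative frequency of each bitstring
    freqs = np.array([counts[b] for b in bitstrings], dtype=float)
    freqs /= freqs.sum()
    # This dot product is the post-processing step that replaces expval
    return float(np.dot(eigvals, freqs))


print(expval_z_from_counts({"0": 750, "1": 250}))  # 0.5
```

Because the circuit now ends in a counts measurement followed by this arithmetic, a check that only accepts expval or probs as terminal measurements no longer recognises it.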

This is not specific to the OQC device: it is a restriction on all devices that do not natively support expval terminal measurements and therefore need to run the measurement-from-samples/counts transform.

In principle, there's a simple fix: swap the order of the measurement-from-counts transform and the parameter-shift transform. The parameter-shift transform would then see the original circuit with the expval op and create the gradient circuits (also with expval ops: each is essentially a copy of the original circuit with a parameter shifted, typically by ±π/2 for standard rotation gates, hence the name), and measurement-from-counts would afterwards transform the gradient circuits' expval ops into counts ops. In practice, however, these two transforms live in completely different stages of the PennyLane software stack, so moving them across those stage boundaries is not straightforward. In short, this feature is not yet supported 😭
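The ordering argument above can be illustrated with a deliberately simplified toy model. Everything here is hypothetical: ToyTape, toy_param_shift, and toy_from_counts are illustrative stand-ins, not PennyLane or Catalyst APIs, and the counts post-processing itself is not modelled. The sketch only shows why running the counts conversion first trips the expval/probs check, while running the parameter-shift step first does not.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyTape:
    """Hypothetical stand-in for a circuit: one rotation angle, one terminal measurement."""
    angle: float
    measurement: str  # e.g. "expval", "probs", "counts"


def toy_param_shift(tape: ToyTape) -> list[ToyTape]:
    """Build the two gradient tapes; mirrors the expval/probs check in the error above."""
    if tape.measurement not in ("expval", "probs"):
        raise ValueError("parameter-shift needs expval or probs, got " + tape.measurement)
    # Gradient tapes are copies of the original with the angle shifted by +/- pi/2
    return [replace(tape, angle=tape.angle + s) for s in (math.pi / 2, -math.pi / 2)]


def toy_from_counts(tapes: list[ToyTape]) -> list[ToyTape]:
    """Rewrite each terminal measurement as counts (post-processing not modelled)."""
    return [replace(t, measurement="counts") for t in tapes]


original = ToyTape(angle=0.1, measurement="expval")

# Current order: counts conversion first, then parameter-shift -> rejected
try:
    toy_param_shift(toy_from_counts([original])[0])
except ValueError as err:
    print("counts first:", err)

# Swapped order: parameter-shift sees expval, then each gradient tape becomes counts
grad_tapes = toy_from_counts(toy_param_shift(original))
print("shift first:", grad_tapes)
```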

There is a workaround: differentiate manually with finite differences.

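import pennylane as qml

weights = 0.1
dev = qml.device("oqc.cloud", backend="lucy", wires=8)

# OQC has no analytic mode, so the number of shots must be set explicitly
@qml.qjit
@qml.set_shots(10000)
@qml.qnode(dev, diff_method='parameter-shift')
def circuit(weights):
    qml.RY(weights, wires=0)
    return qml.expval(qml.PauliZ(0))

# Forward finite difference: two ordinary circuit executions, no qml.grad call
delta = 0.01
print((circuit(weights + delta) - circuit(weights)) / delta)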
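The same two-evaluation pattern can also apply the parameter-shift rule by hand, since (as explained above) the gradient circuits are just the original circuit with shifted parameters. The sketch below is a hedged, device-free illustration: it assumes a single Pauli-rotation parameter, uses cos(x) (the exact <Z> after RY(x) on |0⟩) as a stand-in for the circuit, and parameter_shift_grad and expval_after_ry are hypothetical helpers, not PennyLane or Catalyst APIs.

```python
import math
from typing import Callable


def parameter_shift_grad(f: Callable[[float], float], x: float,
                         shift: float = math.pi / 2) -> float:
    """Two-term parameter-shift rule for a single Pauli-rotation angle.

    Evaluates the same circuit at x + shift and x - shift.
    With shift = pi/2 the denominator 2*sin(shift) equals 2.
    """
    return (f(x + shift) - f(x - shift)) / (2 * math.sin(shift))


# Device-free stand-in: <Z> after RY(x) on |0> is exactly cos(x)
def expval_after_ry(x: float) -> float:
    return math.cos(x)


print(parameter_shift_grad(expval_after_ry, 0.1))  # ≈ -sin(0.1)

# Untested assumption: with the qjit `circuit` from the workaround, the call
# would be parameter_shift_grad(circuit, weights)
```

One possible advantage on a shot-based device: a forward difference divides the shot noise of two estimates by delta (0.01 here), while the parameter-shift rule divides it by 2, so its estimate may be considerably less noisy at the same number of shots.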

Note that since the OQC device has no analytic mode, you need to set the number of shots (the number of repeated circuit executions used to build the final statistics). If you do not set it, a default of 1024 is used. More shots give more accurate results: the statistical error shrinks roughly as 1/√shots.
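To make the shots trade-off concrete, here is a minimal simulation sketch. It assumes an ideal, noise-free single qubit measured in the Z basis and draws outcomes with NumPy's binomial generator rather than a real device; sampled_expval_z and standard_error_z are hypothetical helpers.

```python
import numpy as np


def sampled_expval_z(p0: float, shots: int, rng: np.random.Generator) -> float:
    """Estimate <Z> from `shots` simulated one-qubit measurements.

    p0 is the probability of measuring 0; each shot yields +1 (for 0)
    or -1 (for 1), and the estimate is the sample mean.
    """
    n0 = rng.binomial(shots, p0)           # number of '0' outcomes
    return (n0 - (shots - n0)) / shots     # (+1)*n0 + (-1)*n1, normalised


def standard_error_z(expval: float, shots: int) -> float:
    """Standard deviation of the sampled <Z> estimate: sqrt((1 - <Z>^2) / shots)."""
    return float(np.sqrt((1.0 - expval**2) / shots))


rng = np.random.default_rng(seed=0)
x = 0.1
p0 = np.cos(x / 2) ** 2  # P(0) after RY(x) on |0>, so the exact <Z> is cos(x)
for shots in (1024, 10000):
    estimate = sampled_expval_z(p0, shots, rng)
    print(shots, estimate, standard_error_z(np.cos(x), shots))
```

Increasing shots by a factor of 100 reduces the standard error by a factor of 10, which is the 1/√shots scaling mentioned above.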

Hope this helps!

paul0403, Oct 07 '25 15:10

Hi @paul0403, I appreciate you going in depth on everything that's going on, and the misc fixes you added. You guys 🪨! Thanks also for suggesting the workaround; we'll try that route and see how things go.

The parameter-shift explanation makes sense and matches what I suspected was happening, and I understand why the implementation might not make that easy. Looking forward to when it can be supported 👍

devx5f, Oct 08 '25 01:10

Just want to add that your support and communication on this library have been 🔥

devx5f, Oct 08 '25 01:10

Love to hear your feedback!

Making our existing transforms more device-agnostic is definitely important. If the opportunity arises, we look forward to meeting again on the road ahead ❤️

paul0403, Oct 08 '25 02:10