catalyst
Compilation error for OQC device
Description: Using the OQC device returns a compile error.
Code Snippet:
import pennylane as qml
from catalyst import qjit

dev = qml.device("oqc.cloud", backend="lucy", shots=2012, wires=1)

@qjit
@qml.qnode(dev)
def circuit(x: float):
    qml.RX(x, wires=0)
    return qml.counts(wires=[0])
Error:
CompileError: Device at /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/utils/../lib/librtd_oqc.so cannot be found!
Stack Trace
---------------------------------------------------------------------------
CompileError Traceback (most recent call last)
Cell In[5], line 3
1 dev = qml.device("oqc.cloud", backend="lucy", shots=2012, wires=1)
----> 3 @qjit
4 @qml.qnode(dev)
5 def circuit(x: float):
6 qml.RX(x, wires=0)
7 return qml.counts(wires=[0, 1])
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:502, in qjit(fn, autograph, autograph_include, async_qnodes, target, keep_intermediate, verbose, logfile, pipelines, static_argnums, static_argnames, abstracted_axes, disable_assertions, seed, circuit_transform_pipeline, pass_plugins, dialect_plugins)
499 if fn is None:
500 return functools.partial(qjit, **kwargs)
--> 502 return QJIT(fn, CompileOptions(**kwargs))
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
63 @wraps(func)
64 def wrapper_exit(*args, **kwargs):
---> 65 output = func(*args, **kwargs)
66 if lgr.isEnabledFor(log_level): # pragma: no cover
67 f_string = _get_bound_signature(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:565, in QJIT.__init__(self, fn, compile_options)
563 # Static arguments require values, so we cannot AOT compile.
564 if self.user_sig is not None and not self.compile_options.static_argnums:
--> 565 self.aot_compile()
567 super().__init__("user_function")
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:618, in QJIT.aot_compile(self)
616 # TODO: awkward, refactor or redesign the target feature
617 if self.compile_options.target in ("jaxpr", "mlir", "binary"):
--> 618 self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
619 self.user_sig or ()
620 )
622 if self.compile_options.target in ("mlir", "binary"):
623 self.mlir_module = self.generate_ir()
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/debug/instruments.py:145, in instrument.<locals>.wrapper(*args, **kwargs)
142 @functools.wraps(fn)
143 def wrapper(*args, **kwargs):
144 if not InstrumentSession.active:
--> 145 return fn(*args, **kwargs)
147 with ResultReporter(stage_name, has_finegrained) as reporter:
148 self = args[0]
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:759, in QJIT.capture(self, args, **kwargs)
749 return QFunc.__call__(
750 qnode,
751 *args,
752 **dict(params, **kwargs),
753 )
755 with Patcher(
756 (qml.QNode, "__call__", closure),
757 ):
758 # TODO: improve PyTree handling
--> 759 jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
760 self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, dbg
761 )
762 self.compile_options.pass_plugins.update(plugins)
763 self.compile_options.dialect_plugins.update(plugins)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jax_tracer.py:609, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_info)
603 make_jaxpr_kwargs = {
604 "static_argnums": static_argnums,
605 "abstracted_axes": abstracted_axes,
606 "debug_info": debug_info,
607 }
608 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
--> 609 jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
610 plugins = EvaluationContext.get_plugins()
612 return jaxpr, out_type, out_treedef, plugins
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jax_extras/tracing.py:491, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
489 f, out_tree_promise = flatten_fun(f, in_tree)
490 f = annotate(f, in_type)
--> 491 jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
492 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
493 return closed_jaxpr, out_type, out_tree_promise()
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/profiler.py:334, in annotate_function.<locals>.wrapper(*args, **kwargs)
331 @wraps(func)
332 def wrapper(*args, **kwargs):
333 with TraceAnnotation(name, **decorator_kwargs):
--> 334 return func(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2276, in trace_to_jaxpr_dynamic2(fun)
2274 in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
2275 with core.set_current_trace(trace):
-> 2276 ans = fun.call_wrapped(*in_tracers)
2277 out_tracers = map(trace.to_jaxpr_tracer, ans)
2278 jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/linear_util.py:211, in WrappedFun.call_wrapped(self, *args, **kwargs)
209 def call_wrapped(self, *args, **kwargs):
210 """Calls the transformed function"""
--> 211 return self.f_transformed(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/api_util.py:73, in flatten_fun(f, store, in_tree, *args_flat)
69 @lu.transformation_with_aux2
70 def flatten_fun(f: Callable, store: lu.Store,
71 in_tree: PyTreeDef, *args_flat):
72 py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 73 ans = f(*py_args, **py_kwargs)
74 ans, out_tree = tree_flatten(ans)
75 store.store(out_tree)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/jax/_src/linear_util.py:402, in _get_result_paths_thunk(_fun, _store, *args, **kwargs)
400 @transformation_with_aux2
401 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
--> 402 ans = _fun(*args, **kwargs)
403 result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans))
404 if _store:
405 # In some instances a lu.WrappedFun is called multiple times, e.g.,
406 # the bwd function in a custom_vjp
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/jit.py:749, in QJIT.capture.<locals>.closure(qnode, *args, **kwargs)
746 params["pass_pipeline"] = pass_pipeline
747 params["debug_info"] = dbg
--> 749 return QFunc.__call__(
750 qnode,
751 *args,
752 **dict(params, **kwargs),
753 )
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/qfunc.py:145, in QFunc.__call__(self, *args, **kwargs)
143 new_device = copy(self.device)
144 new_device._shots = self._shots # pylint: disable=protected-access
--> 145 qjit_device = QJITDevice(new_device)
147 static_argnums = kwargs.pop("static_argnums", ())
148 out_tree_expected = kwargs.pop("_out_tree_expected", [])
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
63 @wraps(func)
64 def wrapper_exit(*args, **kwargs):
---> 65 output = func(*args, **kwargs)
66 if lgr.isEnabledFor(log_level): # pragma: no cover
67 f_string = _get_bound_signature(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/device/qjit_device.py:326, in QJITDevice.__init__(self, original_device)
321 if _to_matrix_ops and not device_capabilities.supports_operation("QubitUnitary"):
322 raise CompileError(
323 "The device that specifies to_matrix_ops must support QubitUnitary."
324 )
--> 326 backend = QJITDevice.extract_backend_info(original_device, device_capabilities)
328 self.backend_name = backend.c_interface_name
329 self.backend_lib = backend.lpath
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/device/qjit_device.py:300, in QJITDevice.extract_backend_info(device, capabilities)
296 @staticmethod
297 @debug_logger
298 def extract_backend_info(device, capabilities: DeviceCapabilities) -> BackendInfo:
299 """Wrapper around extract_backend_info in the runtime module."""
--> 300 return extract_backend_info(device, capabilities)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst/device/qjit_device.py:175, in extract_backend_info(device, capabilities)
172 raise CompileError(f"The {dname} device does not provide C interface for compilation.")
174 if not pathlib.Path(device_lpath).is_file():
--> 175 raise CompileError(f"Device at {device_lpath} cannot be found!")
177 if dname == "braket.local.qubit": # pragma: no cover
178 device_kwargs["device_type"] = dname
Version: 0.12.0
Hi @devx5f, thanks for reporting this. I tried to replicate this locally but couldn't. I also checked the published wheels manually, and they do contain the relevant library.
I wonder if this has something to do with Anaconda, since I see it in the path of the error message. There are two things you could try that would help us:
- Try your example in a standard-library venv environment (everything installed with pip). Does the same problem occur?
- Search your installation directory (/opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst) for "librtd_oqc" files. Is the shared library present at all? Maybe it's in an unexpected location.
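For the second check, a small sketch using only the standard library: it lists every file under a directory whose name starts with `librtd_oqc`, whatever its suffix, so a `.dylib` or a library in an unexpected subfolder still shows up. The helper name `find_runtime_libs` is hypothetical, and the path is the one from the error message above; adjust it for your environment.

```python
from pathlib import Path


def find_runtime_libs(root, stem="librtd_oqc"):
    """Return every file under `root` whose name starts with `stem`.

    Any suffix is matched (.so, .dylib, ...), so an unexpected extension
    or location still shows up. Returns [] if `root` is not a directory.
    """
    root = Path(root)
    if not root.is_dir():
        return []
    return sorted(p for p in root.rglob(f"{stem}*") if p.is_file())


# Path taken from the error message; change it to your own site-packages.
site = "/opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/catalyst"
for path in find_runtime_libs(site):
    print(path)
```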
Hi, thank you for looking into this.
I tried it in a venv environment and the error still occurs there.
There's a libcustom_calls.so file in the utils directory, and a librtd_oqc.dylib and an oqc_python_module.so in catalyst/lib/, but I don't see a librtd_oqc.so file anywhere in the catalyst package in the conda environment. I tried it with Python 3.12 and 3.11, and my conda version is 24.9.2, if that helps.
Thanks @devx5f, that's really helpful. Could you confirm which OS (Linux or macOS) you are on, and which architecture (x86 or ARM)?
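If it is easier than checking system settings, both values can be read from Python's standard `platform` module. A minimal sketch (the helper name `describe_platform` is just for illustration):

```python
import platform


def describe_platform():
    """Return (OS name, CPU architecture) as reported by the standard library."""
    # platform.system() gives e.g. "Linux" or "Darwin" (macOS);
    # platform.machine() gives e.g. "x86_64" or "arm64".
    return platform.system(), platform.machine()


print(describe_platform())
```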
I'm on a Mac with the ARM-based M4 Max chip.
Hi @devx5f, thanks for reporting this!
If librtd_oqc.dylib exists in your catalyst/lib/ (which is the correct behavior for Mac users), then I think it might be a bug on our end in the logic that decides whether the suffix is .so or .dylib.
And after some searching, there it is! https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/third_party/oqc/oqc_device.py#L48
We will fix this soon. Thanks again for catching this 👍
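For illustration, choosing the suffix per platform instead of hardcoding `.so` could look like the sketch below. The names `shared_lib_suffix` and `oqc_runtime_lib` are hypothetical, and this is not the fix that landed in Catalyst; it only shows the idea of keying the suffix on `sys.platform`.

```python
import sys


def shared_lib_suffix(platform_name=None):
    """Return the conventional shared-library suffix for a platform.

    macOS (sys.platform == "darwin") uses ".dylib"; Linux and most other
    Unix-like systems use ".so".
    """
    if platform_name is None:
        platform_name = sys.platform
    return ".dylib" if platform_name == "darwin" else ".so"


def oqc_runtime_lib(lib_dir, platform_name=None):
    """Hypothetical helper: path to the OQC runtime library inside lib_dir."""
    return f"{lib_dir}/librtd_oqc{shared_lib_suffix(platform_name)}"


print(oqc_runtime_lib("/path/to/catalyst/lib"))
```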
okay great! Appreciate you looking into it
Just want to confirm that this is actually the issue: can you change the above line locally on your machine to just have the .dylib suffix and see if it works?
It would be on line 48 (in the current version) of
your_venv/lib/python3.10/site-packages/catalyst/third_party/oqc/oqc_device.py:
- return "oqc", get_lib_path("oqc_runtime", "OQC_LIB_DIR") + "/librtd_oqc.so"
+ return "oqc", get_lib_path("oqc_runtime", "OQC_LIB_DIR") + "/librtd_oqc.dylib"
Yep, that fixed it.
However, running the code then gave the following error:
RuntimeError: [/Users/runner/work/catalyst/catalyst/frontend/catalyst/third_party/oqc/src/../../../../../runtime/lib/backend/common/DynamicLibraryLoader.hpp:75][Function:getSymbol] Error in Catalyst Runtime: dlsym(0xebb42b30, counts): symbol not found
Thanks for reporting! There's also another bug with the signature of the counts function... we will fix this soon!
great, thanks again for looking into it
Hi guys, thank you for addressing the issues so promptly 🔥 I tested the fixes on the dev version of the package, and it looks like the issues above are fixed, so thank you for those!
~~When running the circuit, it looks like the circuit is hanging indefinitely though with no error. I tried this with valid creds and invalid creds and got the same behavior with it hanging. We're able to run a circuit with the oqc-qcaas-client so can verify the creds are valid.~~
Testing in a venv environment does return an error:
RuntimeError: [/Users/runner/work/catalyst/catalyst/frontend/catalyst/third_party/oqc/src/./OQCRunner.hpp:48][Function:Counts] Error in Catalyst Runtime: No module named 'qcaas_client'
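import pennylane as qml
from catalyst import qjit

dev = qml.device("oqc.cloud", backend="lucy", shots=2012, wires=1)

@qjit
@qml.qnode(dev)
def circuit(x: float):
    qml.RX(x, wires=0)
    # The device has a single wire (label 0), so measure wire 0.
    return qml.expval(qml.PauliZ(wires=0))

print(circuit(0.77))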
Stack trace
```
RuntimeError Traceback (most recent call last)
Cell In[9], line 1
----> 1 circuit(.77)
File ~/CS/test/test_bed_311/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.
File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/jit.py:602, in QJIT.__call__(self, *args, **kwargs)
599 dynamic_args = filter_static_args(args, self.compile_options.static_argnums)
600 args = promote_arguments(self.c_sig, dynamic_args)
--> 602 return self.run(args, kwargs)
File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/debug/instruments.py:145, in instrument.
File ~/CS/test/test_bed_311/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.
File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/jit.py:835, in QJIT.run(self, args, kwargs)
822 @instrument(has_finegrained=True)
823 @debug_logger
824 def run(self, args, kwargs):
825 """Invoke a previously compiled function with the supplied arguments.
826
827 Args:
(...)
832 Any: results of the execution arranged into the original function's output PyTrees
833 """
--> 835 results = self.compiled_function(*args, **kwargs)
837 # TODO: Move this to the compiled function object.
838 return tree_unflatten(self.out_treedef, results)
File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/compiled_functions.py:341, in CompiledFunction.__call__(self, *args, **kwargs)
337 abi_args, _buffer = self.args_to_memref_descs(self.restype, dynamic_args, **kwargs)
339 numpy_dict = {nparr.ctypes.data: nparr for nparr in _buffer}
--> 341 result = CompiledFunction._exec(
342 self.shared_object,
343 self.restype,
344 self.out_type,
345 numpy_dict,
346 *abi_args,
347 )
349 return result
File ~/CS/test/test_bed_311/lib/python3.11/site-packages/catalyst/compiled_functions.py:162, in CompiledFunction._exec(shared_object, has_return, out_type, numpy_dict, *args)
160 with shared_object as lib:
161 result_desc = type(args[0].contents) if has_return else None
--> 162 retval = wrapper.wrap(lib.function, args, result_desc, lib.mem_transfer, numpy_dict)
164 if out_type is not None:
165 keep_outputs = [k for _, k in out_type]
RuntimeError: [/Users/runner/work/catalyst/catalyst/frontend/catalyst/third_party/oqc/src/./OQCRunner.hpp:48][Function:Counts] Error in Catalyst Runtime: No module named 'qcaas_client'
```
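The error is raised when the runtime fails to import `qcaas_client`. One quick check, sketched below with the standard library's `importlib.util.find_spec`, is whether the interpreter you are running can find that package at all. Note the assumption: this only reports what the current interpreter sees, and the runtime may not be using the same interpreter. The helper name `module_available` is hypothetical.

```python
import importlib.util
import sys


def module_available(name):
    """Return True if the top-level module `name` is importable here."""
    return importlib.util.find_spec(name) is not None


# Which interpreter is running, and can it see the OQC client package?
print("interpreter:", sys.executable)
print("qcaas_client importable:", module_available("qcaas_client"))
```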
Hi guys, I was able to get around the module-not-found error by reverting to the nanobind package. However, when I attempted to train a circuit, I got the following DifferentiableCompileError:
Code:
import catalyst
import numpy as np
import pennylane as qml

num_qubits = 8
num_layers = 2
weights = np.random.random(size=(1,))

dev = qml.device("oqc.cloud", backend="lucy", wires=num_qubits)

@qml.qjit()
@qml.qnode(dev, diff_method='parameter-shift')
def circuit(weights):
    qml.RY(weights, wires=0)
    return qml.expval(qml.PauliZ(0))

loss, grads = catalyst.value_and_grad(circuit)(weights)
Error:
DifferentiableCompileError: The parameter-shift method can only be used for QNodes which return either qml.expval or qml.probs.
Stack Trace:
DifferentiableCompileError Traceback (most recent call last)
Cell In[4], line 1
----> 1 loss, grads = catalyst.value_and_grad(circuit)(weights)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:704, in GradCallable.__call__(self, *args, **kwargs)
702 if self.grad_params.scalar_out:
703 if self.grad_params.with_value:
--> 704 results = jax.value_and_grad(self.fn, argnums=argnums)(*args, **kwargs)
705 else:
706 results = jax.grad(self.fn, argnums=argnums)(*args, **kwargs)
[... skipping hidden 16 frame]
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:596, in QJIT.__call__(self, *args, **kwargs)
594 if self.jaxed_function is None:
595 self.jaxed_function = JAX_QJIT(self) # lazy gradient compilation
--> 596 return self.jaxed_function(*args, **kwargs)
598 elif requires_promotion:
599 dynamic_args = filter_static_args(args, self.compile_options.static_argnums)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:969, in JAX_QJIT.__call__(self, *args, **kwargs)
967 @debug_logger
968 def __call__(self, *args, **kwargs):
--> 969 return self.jaxed_function(*args, **kwargs)
[... skipping hidden 9 frame]
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:944, in JAX_QJIT.compute_jvp(self, primals, tangents)
942 results = self.wrap_callback(self.qjit_function, *primals)
943 results_data, _results_shape = tree_flatten(results)
--> 944 derivatives = self.wrap_callback(self.get_derivative_qjit(argnums), *primals)
945 derivatives_data, _derivatives_shape = tree_flatten(derivatives)
947 jvps = [jnp.zeros_like(results_data[res_idx]) for res_idx in range(len(results_data))]
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:923, in JAX_QJIT.get_derivative_qjit(self, argnums)
920 deriv_wrapper.__annotations__ = annotations
921 deriv_wrapper.__signature__ = signature.replace(parameters=updated_params)
--> 923 self.derivative_functions[argnum_key] = QJIT(
924 deriv_wrapper, self.qjit_function.compile_options
925 )
926 return self.derivative_functions[argnum_key]
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
63 @wraps(func)
64 def wrapper_exit(*args, **kwargs):
---> 65 output = func(*args, **kwargs)
66 if lgr.isEnabledFor(log_level): # pragma: no cover
67 f_string = _get_bound_signature(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:559, in QJIT.__init__(self, fn, compile_options)
557 # Static arguments require values, so we cannot AOT compile.
558 if self.user_sig is not None and not self.compile_options.static_argnums:
--> 559 self.aot_compile()
561 super().__init__("user_function")
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:612, in QJIT.aot_compile(self)
610 # TODO: awkward, refactor or redesign the target feature
611 if self.compile_options.target in ("jaxpr", "mlir", "binary"):
--> 612 self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
613 self.user_sig or ()
614 )
616 if self.compile_options.target in ("mlir", "binary"):
617 self.mlir_module = self.generate_ir()
File ~/CS/test/test_bed/catalyst/frontend/catalyst/debug/instruments.py:145, in instrument.<locals>.wrapper(*args, **kwargs)
142 @functools.wraps(fn)
143 def wrapper(*args, **kwargs):
144 if not InstrumentSession.active:
--> 145 return fn(*args, **kwargs)
147 with ResultReporter(stage_name, has_finegrained) as reporter:
148 self = args[0]
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:759, in QJIT.capture(self, args, **kwargs)
749 return QFunc.__call__(
750 qnode,
751 *args,
752 **dict(params, **kwargs),
753 )
755 with Patcher(
756 (qml.QNode, "__call__", closure),
757 ):
758 # TODO: improve PyTree handling
--> 759 jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
760 self.user_function, static_argnums, abstracted_axes, full_sig, kwargs, dbg
761 )
762 self.compile_options.pass_plugins.update(plugins)
763 self.compile_options.dialect_plugins.update(plugins)
File /opt/anaconda3/envs/test_bed/lib/python3.12/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jax_tracer.py:613, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_info)
607 make_jaxpr_kwargs = {
608 "static_argnums": static_argnums,
609 "abstracted_axes": abstracted_axes,
610 "debug_info": debug_info,
611 }
612 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
--> 613 jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
614 plugins = EvaluationContext.get_plugins()
616 return jaxpr, out_type, out_treedef, plugins
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jax_extras/tracing.py:483, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
481 f, out_tree_promise = flatten_fun(f, in_tree)
482 f = annotate(f, in_type)
--> 483 jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
484 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
485 return closed_jaxpr, out_type, out_tree_promise()
[... skipping hidden 5 frame]
File ~/CS/test/test_bed/catalyst/frontend/catalyst/jit.py:917, in JAX_QJIT.get_derivative_qjit.<locals>.deriv_wrapper(*args, **kwargs)
916 def deriv_wrapper(*args, **kwargs):
--> 917 return catalyst.jacobian(self.qjit_function, argnums=argnums)(*args, **kwargs)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:664, in GradCallable.__call__(self, *args, **kwargs)
654 grad_params = _check_grad_params(
655 self.grad_params.method,
656 self.grad_params.scalar_out,
(...) 661 self.grad_params.with_value,
662 )
663 input_data_flat, _ = tree_flatten((args, kwargs))
--> 664 jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *args, **kwargs)
665 if self.grad_params.with_value: # use value_and_grad
666 args_argnum = tuple(args[i] for i in grad_params.argnums)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:835, in _make_jaxpr_check_differentiable(f, grad_params, *args, **kwargs)
829 if res.dtype.kind != "f":
830 raise DifferentiableCompileError(
831 "Catalyst.grad/jacobian only supports differentiation on floating-point "
832 f"results, got '{res.dtype}' at position {pos}."
833 )
--> 835 _verify_differentiable_child_qnodes(jaxpr, method)
836 return jaxpr, out_tree
File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:862, in _verify_differentiable_child_qnodes(jaxpr, method)
859 traverse_children(child_jaxpr)
860 visited.add(py_callable)
--> 862 traverse_children(jaxpr)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:858, in _verify_differentiable_child_qnodes.<locals>.traverse_children(jaxpr)
856 if py_callable not in visited:
857 if isinstance(py_callable, QNode):
--> 858 _check_qnode_against_grad_method(py_callable, method, child_jaxpr)
859 traverse_children(child_jaxpr)
860 visited.add(py_callable)
File ~/CS/test/test_bed/catalyst/frontend/catalyst/api_extensions/differentiation.py:895, in _check_qnode_against_grad_method(f, method, jaxpr)
887 raise DifferentiableCompileError(
888 "Cannot differentiate a QNode explicitly marked non-differentiable (with "
889 "diff_method=None)."
890 )
892 if f.diff_method == "parameter-shift" and any(
893 prim not in [expval_p, probs_p] for prim in return_ops
894 ):
--> 895 raise DifferentiableCompileError(
896 "The parameter-shift method can only be used for QNodes "
897 "which return either qml.expval or qml.probs."
898 )
900 if f.diff_method == "adjoint" and any(prim not in [expval_p] for prim in return_ops):
901 raise DifferentiableCompileError(
902 "The adjoint method can only be used for QNodes which return qml.expval."
903 )
DifferentiableCompileError: The parameter-shift method can only be used for QNodes which return either qml.expval or qml.probs.
It works fine with the lightning.qubit device, and I noticed that with the OQC device, the return op checked in _check_qnode_against_grad_method turns into a dot_general instead of an expval.
Hi @devx5f, sorry for the late reply. Thanks for reporting the gradient error and providing the stack trace. We will take a look soon.
Hi @devx5f, it turns out there are essentially several separate issues:
- Connection issues: various small connection issues with the plugin. These are being fixed in #2089, which is scheduled to go into the next release (happening next week).
- Indefinite hanging: we discovered that the OQC cloud hangs when a QPU ID is not provided. This was caused by a bug in our connection code, and it is now fixed.
- Module-not-found error: I could not reproduce this. I suspect it's related to the conda environment, but I'm not sure. Regardless, I have updated the OQC runtime device's pybind bindings so that they no longer start a new Python interpreter and instead use the user's top-level Python interpreter. Hopefully this solves the package-path issues 🙏
- Differentiability error: the parameter-shift method only works on circuits that return expectation values or probabilities (as the error message says); this is a property of the method itself. However, the OQC device only returns raw counts, so when a user circuit returns an expectation value, Catalyst internally transforms the expval call into a raw counts call plus post-processing math (the dot product op you're seeing). The differentiation check therefore sees raw counts in the circuit and rejects it.
This is not specific to the OQC device: it is a restriction on all devices that do not natively support expval terminal measurements and therefore need to run the measurement-from-samples/counts transform.
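To see why a dot product appears, consider a single qubit measured in the Z basis: outcome "0" has Pauli-Z eigenvalue +1 and outcome "1" has eigenvalue -1, so the expectation value is the dot product of those eigenvalues with the observed frequencies. The NumPy sketch below illustrates that post-processing idea under the assumption of bitstring keys in the counts dictionary; it is not Catalyst's actual transform, and `expval_z_from_counts` is a hypothetical name.

```python
import numpy as np


def expval_z_from_counts(counts):
    """Estimate <Z> on one qubit from a counts dict like {"0": n0, "1": n1}.

    <Z> = (+1) * P(0) + (-1) * P(1), i.e. a dot product of eigenvalues
    with the empirical outcome frequencies.
    """
    outcomes = ["0", "1"]
    eigvals = np.array([1.0, -1.0])
    freqs = np.array([counts.get(k, 0) for k in outcomes], dtype=float)
    freqs /= freqs.sum()  # counts -> empirical probabilities
    return float(np.dot(eigvals, freqs))


print(expval_z_from_counts({"0": 750, "1": 250}))  # 0.5
```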
In principle, there's a simple fix: swap the order of the measurement-from-counts transform and the parameter-shift transform. The parameter-shift transform would then see the original circuit with the expval op and create the gradient circuits (also with expval ops; they are essentially duplicates of the original circuit with shifted parameter values, hence the name), and measurement-from-counts would afterwards transform the gradient circuits' expval ops into counts ops. In practice, however, there is an issue: these two transforms live in two completely different stages of the PennyLane software stack, so moving them across these stage boundaries is not immediately doable. Essentially, this is a feature that is not yet supported/implemented 😭
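For reference, the two-term parameter-shift rule for a gate generated by a Pauli operator (e.g. RX, RY, RZ) needs only two evaluations of the same circuit at shifted angles: d/dx f(x) = [f(x + s) - f(x - s)] / (2 sin s), with s = π/2 as the usual choice. The sketch checks it on cos(x), which is the ideal <Z> after RY(x) on |0>. The helper name `parameter_shift_grad` is hypothetical, and this is an illustration of the rule, not PennyLane's or Catalyst's implementation.

```python
import math


def parameter_shift_grad(f, x, shift=math.pi / 2):
    """Two-term parameter-shift derivative of f at x.

    Exact when f(x) = a*cos(x) + b*sin(x) + c, which is the form of an
    expectation value as a function of one Pauli-rotation angle.
    """
    return (f(x + shift) - f(x - shift)) / (2 * math.sin(shift))


# Ideal <Z> after RY(x) on |0> is cos(x), whose derivative is -sin(x).
print(parameter_shift_grad(math.cos, 0.1))  # should match -math.sin(0.1) up to rounding
```

Because it only needs circuit evaluations, the same rule could in principle be applied by hand to a compiled circuit that takes a single scalar angle and returns `qml.expval`; that usage is untested against the OQC device. Since the shifts are large (±π/2) rather than a small step, the difference is not divided by a small number, so shot noise is not amplified the way it is in a small-step finite difference.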
There is a workaround: differentiate manually with finite differences.
import pennylane as qml

weights = 0.1
dev = qml.device("oqc.cloud", backend="lucy", wires=8)

@qml.qjit
@qml.set_shots(10000)
@qml.qnode(dev, diff_method='parameter-shift')
def circuit(weights):
    qml.RY(weights, wires=0)
    return qml.expval(qml.PauliZ(0))

# Forward finite-difference estimate of d<Z>/d(weights).
delta = 0.01
print((circuit(weights + delta) - circuit(weights)) / delta)
Note that since the OQC device does not have an analytic mode, you need to set the number of shots (the number of repeated circuit runs used to generate the final statistics). If you do not set it, a default of 1024 is used. More shots give more accurate results.
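To put a rough number on "more accurate": each shot of a Pauli-Z measurement yields +1 or -1, so the single-shot variance is 1 - <Z>² and the standard error of the estimate shrinks as 1/sqrt(shots). In a forward finite difference, the noise of two such estimates is then divided by delta. For the example above (weights = 0.1, ideal noiseless <Z> = cos(0.1), hardware noise ignored), 10000 shots give about 0.001 per evaluation, and dividing by delta = 0.01 gives roughly 0.14 in the quotient, larger than the true derivative -sin(0.1) ≈ -0.10. The sketch assumes independent evaluations with similar error; the helper names are hypothetical.

```python
import math


def expval_std_error(expval, shots):
    """Standard error of a shot-based estimate of a +/-1-valued observable."""
    variance = max(0.0, 1.0 - expval**2)  # single-shot variance; clamp rounding
    return math.sqrt(variance / shots)


def forward_difference_noise(expval, shots, delta):
    """Approximate shot-noise std of (f(x + delta) - f(x)) / delta.

    Assumes the two evaluations are independent with about the same error.
    """
    return math.sqrt(2.0) * expval_std_error(expval, shots) / delta


print(expval_std_error(0.0, 10_000))                          # about 0.01
print(forward_difference_noise(math.cos(0.1), 10_000, 0.01))  # about 0.14
```

Larger shot counts, a larger delta, or a rule with large shifts all reduce this amplification; which trade-off is acceptable depends on the shot budget.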
Hope this helps!
Hi @paul0403, I appreciate you going in depth on everything that's going on and the misc fixes you added. You guys 🪨! And thanks for suggesting the workaround; we'll try that route and see how things go.
The parameter-shift change makes sense and matches what I thought was originally happening, but I understand how the implementation might not make that easy. I look forward to when it can be supported 👍
Just want to add that your support and communication on this library have been 🔥
Love to hear your feedback!
Making our existing transforms more device-agnostic is definitely very important. If the opportunity arises, we look forward to meeting again on the road ahead ❤️