Catalyst does not support QJIT-compiling a parameterized circuit with `qml.FlipSign`
We discovered this issue when attempting to QJIT-compile a circuit implementing Grover's algorithm.
Consider the following PennyLane program that applies the `qml.FlipSign` operator:
```python
import numpy as np
import pennylane as qml

NUM_QUBITS = 2
dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = np.array([0., 0.])
state = circuit(basis_state)
```
As expected, the circuit flips the sign of the $|00\rangle$ basis state:
```pycon
>>> print(state)
[-1.-0.j 0.+0.j 0.+0.j 0.+0.j]
```
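For reference, the action of `qml.FlipSign` on a computational basis state $|b\rangle = |b_0 b_1 \cdots b_{n-1}\rangle$ can be written as a reflection. This is our own summary of the operator's behaviour (assuming PennyLane's convention that wire 0 is the most significant bit), not an excerpt from the PennyLane source:

$$
U_b = I - 2\,|b\rangle\langle b|, \qquad k(b) = \sum_{i=0}^{n-1} b_i\, 2^{\,n-1-i},
$$

$$
U_{00}\,|00\rangle = |00\rangle - 2\,|00\rangle\langle 00|00\rangle = -|00\rangle .
$$

Since the device starts in $|00\rangle$ and $k(00) = 0$, only the amplitude at index 0 changes sign, which matches the printed state.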
When we attempt to QJIT-compile and execute this circuit, we get an error:
```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2
dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = jnp.array([0., 0.])
state = circuit(basis_state)
```
```
Traceback (most recent call last):
  ...
  File ".../venv/lib/python3.12/site-packages/catalyst/device/decomposition.py", line 82, in catalyst_decomposer
    return op.decomposition()
           ^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/operation.py", line 1337, in decomposition
    return self.compute_decomposition(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/templates/subroutines/flip_sign.py", line 144, in compute_decomposition
    if arr_bin[-1] == 0:
       ^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 712, in __bool__
    return self.aval._bool(self)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 1475, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```
The error occurs in the `FlipSign.compute_decomposition()` method:
```python
@staticmethod
def compute_decomposition(wires, arr_bin):
    op_list = []

    if arr_bin[-1] == 0:
        op_list.append(qml.X(wires[-1]))

    op_list.append(qml.ctrl(qml.Z(wires[-1]), control=wires[:-1], control_values=arr_bin[:-1]))

    if arr_bin[-1] == 0:
        op_list.append(qml.X(wires[-1]))

    return op_list
```
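To see why this decomposition needs concrete values, here is a NumPy-only re-implementation of the same gate sequence (X on the last wire, a Z on the last wire controlled on `arr_bin[:-1]`, then X again). The helper `decomposition_unitary` is hypothetical, and we assume PennyLane's big-endian convention (wire 0 is the most significant bit). It builds the matrix of the sequence so it can be compared with the sign-flip diagonal; note that it reuses the Python `if arr_bin[-1] == 0` test, which works here only because `arr_bin` is a concrete list rather than a JAX tracer.

```python
import itertools

import numpy as np


def decomposition_unitary(arr_bin):
    """Matrix of the FlipSign decomposition gate sequence for a concrete 0/1 list."""
    arr_bin = [int(b) for b in arr_bin]
    n = len(arr_bin)
    dim = 2**n
    pauli_x = np.array([[0.0, 1.0], [1.0, 0.0]])
    # X on the last wire (least significant bit in big-endian ordering).
    x_last = np.kron(np.eye(dim // 2), pauli_x)

    # Controlled Z on the last wire: phase -1 on basis states whose leading bits
    # match arr_bin[:-1] (the control values) and whose last bit is 1.
    ctrl_z_diag = np.ones(dim)
    for idx, bits in enumerate(itertools.product([0, 1], repeat=n)):
        if list(bits[:-1]) == arr_bin[:-1] and bits[-1] == 1:
            ctrl_z_diag[idx] = -1.0
    u = np.diag(ctrl_z_diag)

    # Same Python branch as compute_decomposition: fine for concrete ints,
    # but this is exactly what fails when arr_bin is a JAX tracer.
    if arr_bin[-1] == 0:
        u = x_last @ u @ x_last
    return u


# Expect -1 on the diagonal at index 0 (the |00> state) and +1 elsewhere.
u = decomposition_unitary([0, 0])
print(np.diag(u))
```

The conjugation by X on the last wire simply relabels which basis state receives the -1 phase, so the whole sequence is diagonal with a single -1 at the index of `arr_bin`.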
The problem lies in statements such as `if arr_bin[-1] == 0`: under `qjit`, `arr_bin` is a traced JAX array, and using it in native Python control flow requires converting a tracer to a Python `bool`, which JAX does not allow.
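One possible direction for a fix, sketched under assumptions we have not verified in Catalyst: because `qml.FlipSign` is diagonal, its diagonal can be computed from a (possibly traced) basis state using only array arithmetic and comparisons, so no Python `bool` conversion ever happens. The helper `flip_sign_diagonal` and its `xp` parameter are hypothetical; the idea that the result could be passed to an operator such as `qml.DiagonalQubitUnitary` inside a `qjit`-compiled circuit is an untested assumption. The sketch uses NumPy so it runs standalone.

```python
import numpy as np


def flip_sign_diagonal(basis_state, xp=np):
    """Diagonal of I - 2|b><b| for a 0/1 basis-state array, without Python branching.

    Hypothetical helper: only array arithmetic and comparisons are used, so with
    xp=jax.numpy the same code should be traceable (not verified under qjit).
    """
    n = basis_state.shape[0]  # number of wires is static even when values are traced
    # Big-endian weights: wire 0 is the most significant bit.
    weights = 2 ** xp.arange(n - 1, -1, -1)
    index = xp.sum(basis_state * weights)
    # One-hot mask at `index`, built by comparison instead of an `if` statement.
    mask = xp.arange(2**n) == index
    return 1.0 - 2.0 * mask


diag = flip_sign_diagonal(np.array([0.0, 0.0]))
# Applying the diagonal to |00> reproduces the sign flip from the uncompiled example.
state = diag * np.array([1.0, 0.0, 0.0, 0.0])
print(state)
```

The trade-off is that this builds a vector of length $2^n$, so it only suits small wire counts; it is not a replacement for the gate-level decomposition on larger circuits.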
Compiling the circuit with AutoGraph enabled (`@qjit(autograph=True)`) gives the same error, because AutoGraph conversion is disabled by default for all PennyLane modules. To work around this, we followed the "Adding modules for Autograph conversion" docs and tried the following, which results in a different error:
```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2
dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit(autograph=True, autograph_include=["pennylane.templates.subroutines.flip_sign"])
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = jnp.array([0.0, 0.0])
state = circuit(basis_state)
```
```
Traceback (most recent call last):
  ...
  File ".../venv/lib/python3.12/site-packages/catalyst/autograph/ag_primitives.py", line 579, in converted_call
    return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/malt/impl/api.py", line 380, in converted_call
    result = converted_f(*effective_args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_1_ucoey.py", line 35, in ag____call__
    ag__.if_stmt(ag__.converted_call(ag__.ld(enabled), (), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File ".../venv/lib/python3.12/site-packages/catalyst/autograph/ag_primitives.py", line 132, in if_stmt
    results = functional_cond()
              ^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 736, in __call__
    return self._call_with_quantum_ctx(ctx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 662, in _call_with_quantum_ctx
    _assert_cond_result_structure([s.out_tree() for s in out_sigs])
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 1319, in _assert_cond_result_structure
    raise TypeError(
TypeError: Conditional requires a consistent return structure across all branches! Got PyTreeDef((*, CustomNode(FlipSign[(Wires([0, 1]), (('n', (Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/1)>, Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/1)>)),))], []))) and PyTreeDef((*, *)).
```
Catalyst and/or PennyLane should be updated to support `qml.FlipSign` in QJIT-compiled circuits where the basis state passed to `qml.FlipSign` is an input argument of the parameterized circuit.