CUDA Quantum interpreter: Provide semantics for `qcond_p` when the condition is not a mid-circuit measurement
Context
Catalyst has recently added support for executing quantum programs on NVIDIA's CUDA Quantum platform. For example, the following code defines two equivalent quantum programs. Both apply an RX gate with parameter `a` and return the state of the system.
import pennylane as qml
from catalyst import qjit
from catalyst.cuda import qjit as cjit, SoftwareQQPP

# Compiled with Catalyst and run on the lightning.qubit simulator.
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def foo(a):
    qml.RX(a, wires=0)
    return qml.state()

# Translated to CUDA Quantum and run on the qpp-cpu simulator.
@cjit
@qml.qnode(SoftwareQQPP(wires=1))
def bar(a):
    qml.RX(a, wires=0)
    return qml.state()
These equivalent programs run on different simulators: the first targets the lightning.qubit simulator, while the second targets the qpp-cpu simulator.
Both programs are written with PennyLane's API. However, to execute a program on the qpp-cpu simulator, we first need to translate it into NVIDIA's CUDA Quantum Python API. Written directly in that API, the program above could look like the following:
import cudaq

def bar(a):
    # A kernel with no runtime arguments; `a` is added as a constant angle.
    kernel = cudaq.make_kernel()
    qreg = kernel.qalloc(1)
    qubit0 = qreg[0]
    kernel.rx(a, qubit0)
    return cudaq.get_state(kernel)
Goal
Support for translating quantum programs written with PennyLane's API into NVIDIA's CUDA Quantum Python API is limited at the moment. In particular, we cannot yet translate PennyLane's conditional statements into CUDA Quantum. Here is how one would express a conditional statement in PennyLane with Catalyst:
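import pennylane as qml
from catalyst import qjit, cond, measure

@qjit()
@qml.qnode(qml.device("lightning.qubit", wires=6))
def circuit(pred: bool):
    # Apply X to wire 0 when pred is True, otherwise apply the identity.
    @cond(pred)
    def conditional_flip():
        qml.PauliX(0)

    # The else branch takes the same (empty) argument list as the call below.
    @conditional_flip.otherwise
    def conditional_flip():
        qml.Identity(0)

    conditional_flip()
    return measure(wires=0)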
CUDA Quantum's Python API also lets users specify conditionals. However, we will not use CUDA Quantum's conditionals here, because its `c_if` operation requires a mid-circuit measurement as its condition and does not generalize to other booleans. Assume that the condition is not a mid-circuit measurement value.
Instead, the program above should be translated into the following CUDA Quantum Python API calls, where the branch is selected by the value of `pred` while the kernel is being built:
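import cudaq

def circuit(pred):
    # pred is a concrete Python bool, known while the kernel is being built,
    # so only one branch's gates are ever added to the kernel.
    kernel = cudaq.make_kernel()
    qreg = kernel.qalloc(1)
    qubit0 = qreg[0]
    if pred:
        kernel.x(qubit0)  # corresponds to qml.PauliX(0)
    else:
        pass  # qml.Identity(0): no gate is needed
    return cudaq.get_state(kernel)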
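To make this build-time behaviour concrete without requiring a CUDA Quantum installation, the sketch below swaps `cudaq.make_kernel()` for a hypothetical `RecordingKernel` class that only logs the gates added to it; it is not part of Catalyst or CUDA Quantum. Because `pred` is an ordinary Python bool, the `if` runs once while the kernel is built, and the resulting instruction list contains no conditional operation at all.

```python
from dataclasses import dataclass, field


@dataclass
class RecordingKernel:
    """Hypothetical stand-in for a CUDA Quantum kernel: it only records gates."""

    num_qubits: int = 0
    ops: list = field(default_factory=list)

    def qalloc(self, n):
        # Return the indices of the newly allocated qubits.
        start = self.num_qubits
        self.num_qubits += n
        return list(range(start, start + n))

    def x(self, qubit):
        self.ops.append(("x", qubit))


def build_circuit(pred: bool) -> RecordingKernel:
    kernel = RecordingKernel()
    qubit0 = kernel.qalloc(1)[0]
    # The branch is resolved here, in Python, before any simulation happens.
    if pred:
        kernel.x(qubit0)
    # The else branch (identity) emits nothing.
    return kernel


print(build_circuit(True).ops)   # [('x', 0)]
print(build_circuit(False).ops)  # []
```

Either way the kernel is straight-line code; the choice between branches has already been made by the time it is simulated.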
Technical details
- PennyLane's API calls are converted to CUDA Quantum's Python API via a custom JAX interpreter found in `catalyst.cuda.catalyst_to_cuda_interpreter.py`.
- You will need to implement the semantics for Catalyst's JAX primitive `cond_p`.
- Write a function that takes an `InterpreterContext` and a `cond_p` equation and checks the parameter to `cond_p`.
- Depending on the parameter, execute one branch or the other.
- Assume the parameter to `cond_p` is not a mid-circuit measurement.
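The parameters of Catalyst's `cond_p` equation and the fields of `InterpreterContext` are not shown in this issue, so the sketch below does not use them. As a stand-in, it applies the same idea to JAX's own `cond` primitive (produced by `jax.lax.cond`): a small jaxpr interpreter reads the predicate, assumed to be a concrete value rather than a mid-circuit measurement, selects one branch jaxpr, and interprets only that branch. The `chosen` list is a hypothetical hook that records which branch ran; in the CUDA Quantum translator, the analogous effect would be that only the selected branch's gates are appended to the kernel.

```python
import jax
import jax.numpy as jnp
from jax import lax


def program(pred, x):
    # lax.cond emits a `cond` equation; branches[0] is the false branch,
    # branches[1] is the true branch.
    return lax.cond(pred, lambda y: y * 2.0, lambda y: y + 1.0, x)


def eval_jaxpr(jaxpr, consts, *args, chosen=None):
    """Interpret a jaxpr on concrete values, resolving `cond` eagerly."""
    env = {}

    def read(var):
        # Literals carry their value inline; variables live in `env`.
        return var.val if hasattr(var, "val") else env[var]

    for var, val in zip(jaxpr.constvars, consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive.name == "cond":
            # The branch index is concrete (not a mid-circuit measurement),
            # so pick one branch and interpret only that branch.
            index, *operands = invals
            branch = eqn.params["branches"][int(index)]
            if chosen is not None:
                chosen.append(int(index))
            outvals = eval_jaxpr(branch.jaxpr, branch.consts, *operands, chosen=chosen)
        else:
            # Any other primitive is evaluated directly on concrete inputs.
            result = eqn.primitive.bind(*invals, **eqn.params)
            outvals = result if eqn.primitive.multiple_results else [result]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars]


def interpret(fn, *args, chosen=None):
    closed = jax.make_jaxpr(fn)(*args)
    return eval_jaxpr(closed.jaxpr, closed.consts, *args, chosen=chosen)


chosen = []
out = interpret(program, jnp.asarray(True), jnp.asarray(3.0), chosen=chosen)
print(float(out[0]), chosen)  # 6.0 [1]
```

In this sketch the untaken branch is never interpreted, so nothing from it would reach a generated kernel.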