CUDA Quantum interpreter: Provide semantics for `qcond_p` when the condition is not a mid circuit measurement

Open erick-xanadu opened this issue 1 year ago • 0 comments

Context

Catalyst has recently added support for executing quantum programs on NVIDIA's CUDA Quantum platform. For example, the following code contains two identical quantum programs. Both apply the RX gate with parameter a and return the state of the system.

import pennylane as qml
from catalyst import qjit
from catalyst.cuda import qjit as cjit, SoftwareQQPP

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def foo(a):
   qml.RX(a, wires=0)
   return qml.state()

@cjit
@qml.qnode(qml.device(SoftwareQQPP(wires=1)))
def bar(a):
   qml.RX(a, wires=0)
   return qml.state()

These equivalent quantum programs run on different simulators. The first one has been specified to run on the lightning.qubit simulator, while the second one has been specified to run on the qpp-cpu simulator.

Both programs are written using PennyLane's API. However, in order to execute them on the qpp-cpu simulator, we first need to translate them into calls to NVIDIA's CUDA Quantum Python API. The second program above, written directly against NVIDIA's CUDA Quantum Python API, could look like the following:

import cudaq

def bar(a):
  kernel = cudaq.make_kernel()
  qreg = kernel.qalloc(1)
  qubit0 = qreg[0]
  kernel.rx(a, qubit0)
  return cudaq.get_state(kernel)

Goal

Support for translating quantum programs written in PennyLane's API into NVIDIA's CUDA Quantum Python API is limited at the moment. In particular, we don't have support for translating conditional statements into CUDA Quantum. Here is how one would express a conditional statement in PennyLane with Catalyst.

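from catalyst import cond, measure, qjit

@qjit()
@qml.qnode(qml.device("lightning.qubit", wires=6))
def circuit(pred: bool):
    # Branch taken when pred is True
    @cond(pred)
    def conditional_flip():
        qml.PauliX(0)

    # Branch taken when pred is False; branch functions take no arguments
    @conditional_flip.otherwise
    def conditional_flip():
        qml.Identity(0)

    conditional_flip()

    return measure(wires=0)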

CUDA Quantum's Python API also allows users to specify conditionals. Please note that we are not going to use CUDA Quantum's conditionals here, because CUDA Quantum's c_if operation requires a mid circuit measurement as its condition and does not generalize to other booleans. For this issue, assume that the predicate is never a mid circuit measurement value.

Instead, the above program should be translated into the following CUDA Quantum Python API calls, where only the branch selected by the value of the input pred is emitted into the kernel:

def circuit(pred):
  # pred is known to be True
  kernel = cudaq.make_kernel()
  qreg = kernel.qalloc(1)
  qubit0 = qreg[0]
  if pred:
    kernel.x(qubit0)  # PauliX from the true branch
  else:
    ...
  return cudaq.get_state(kernel)

Technical details

  • PennyLane's API calls are converted to CUDA Quantum's Python API via a custom JAX interpreter, found in catalyst.cuda.catalyst_to_cuda_interpreter.py.
  • You will need to implement the semantics for Catalyst's JAX primitive cond_p.
  • Write a function that takes an InterpreterContext and a cond_p equation and inspects the predicate parameter passed to cond_p.
  • Depending on the parameter, execute one branch or the other.
  • Assume the parameter to cond_p is not a mid circuit measurement.
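
The steps above could be sketched roughly as follows. This is a toy model only, not the real interpreter: ToyContext, ToyEquation, MeasurementValue and interpret_cond are hypothetical stand-ins for InterpreterContext, a cond_p equation, a mid circuit measurement result and the handler to be written, and gates are recorded as plain tuples instead of being sent to a cudaq kernel.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

# All names below are hypothetical stand-ins for illustration;
# none of them are Catalyst's or CUDA Quantum's real API.

Gate = Tuple[str, int]  # (gate name, wire)


@dataclass(frozen=True)
class MeasurementValue:
    """Marks a value that came from a mid circuit measurement."""
    wire: int


@dataclass
class ToyEquation:
    """Toy analogue of a cond_p equation: one predicate, two branches."""
    pred_var: str
    true_branch: List[Gate]
    false_branch: List[Gate]


@dataclass
class ToyContext:
    """Toy analogue of the interpreter context."""
    variables: Dict[str, Any]  # variable name -> concrete value
    emitted: List[Gate] = field(default_factory=list)  # gates sent to the kernel


def interpret_cond(ctx: ToyContext, eqn: ToyEquation) -> None:
    """Look up the predicate and emit only the selected branch."""
    pred = ctx.variables[eqn.pred_var]
    if isinstance(pred, MeasurementValue):
        # Out of scope here: the predicate is assumed not to be a measurement.
        raise NotImplementedError("predicate is a mid circuit measurement")
    branch = eqn.true_branch if bool(pred) else eqn.false_branch
    ctx.emitted.extend(branch)


def build(pred: Any) -> List[Gate]:
    """Interpret a single conditional X gate for a given predicate value."""
    ctx = ToyContext(variables={"pred": pred})
    # The false branch is modelled as emitting nothing (an assumption).
    eqn = ToyEquation("pred", true_branch=[("x", 0)], false_branch=[])
    interpret_cond(ctx, eqn)
    return ctx.emitted
```

With this model, build(True) records a single X gate on wire 0 and build(False) records nothing, so the resulting kernel never contains a conditional. The real handler would presumably interpret the chosen branch's equations through the existing interpreter rather than copy a list of gates.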

erick-xanadu, Feb 16 '24 20:02