catalyst
[Frontend][CUDAQ] Support `for_p` for CUDA Quantum Interpreter (#524)
Context
Catalyst supports the CUDA Quantum (CUDAQ) interpreter as a simulator backend (`SoftwareQQPP`) through a custom Jaxpr interpreter that translates the input program into CUDA Quantum Python API calls. This PR adds support for the `for_p` (for loop) primitive on the CUDAQ backend.
Goal
The goal is to translate Catalyst programs containing for loops into CUDAQ Python API calls. For example, consider this input program:
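```python
import pennylane as qml
from catalyst import for_loop, qjit

@qjit()
@qml.qnode(qml.device(backend, wires=6))  # backend: the device under test
def circuit(n: int):
    qml.Hadamard(wires=0)

    @for_loop(0, n - 1, 1)
    def loop_fn(i):
        qml.CNOT(wires=[i, i + 1])

    loop_fn()
    return qml.state()
```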
It should be translated to the following CUDAQ API calls:
```python
import cudaq

def circuit(n: int):
    kernel = cudaq.make_kernel()
    qreg = kernel.qalloc(6)
    qubit0 = qreg[0]
    kernel.h(qubit0)

    def loop(index):
        qubit_i = qreg[index]
        qubit_i_plus_1 = qreg[index + 1]
        kernel.cx(qubit_i, qubit_i_plus_1)

    kernel.for_loop(start=0, stop=n - 1, function=loop)
    return cudaq.get_state(kernel)
```
Approach
Construct a function that can be passed to `cudaq.for_loop` and that interprets the loop body at a given iteration.
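A minimal, self-contained sketch of this approach follows. `MockKernel`, `SymbolicIndex`, `LoopContext`, and the toy loop-body format are hypothetical stand-ins, not the CUDAQ or Catalyst API. The mock `for_loop` mirrors the builder behavior described later in this thread: it calls its body function once with a symbolic index and wraps the recorded ops into a single loop.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class SymbolicIndex:
    """Hypothetical stand-in for the symbolic loop index a builder passes in."""

    name: str = "i"


@dataclass
class MockKernel:
    """Hypothetical kernel builder that records gates instead of emitting MLIR."""

    ops: list = field(default_factory=list)

    def h(self, target):
        self.ops.append(("h", target))

    def for_loop(self, start, stop, function):
        # Tracing behavior: call the body once with a symbolic index and
        # wrap whatever it recorded into a single loop op.
        body_start = len(self.ops)
        function(SymbolicIndex())
        body = self.ops[body_start:]
        del self.ops[body_start:]
        self.ops.append(("for", start, stop, body))


class LoopContext:
    """Closes over an interpreter environment and a toy loop body."""

    def __init__(self, kernel, body):
        self.kernel = kernel
        self.body = body  # list of (gate_name, [operand_names]) pairs
        self.env = {}

    def interp_iter(self, iteration):
        """Passed to for_loop: interpret the loop body for the given index."""
        self.env["i"] = iteration
        for gate, operands in self.body:
            args = [self.env[name] for name in operands]
            getattr(self.kernel, gate)(*args)


kernel = MockKernel()
ctx = LoopContext(kernel, body=[("h", ["i"])])
kernel.for_loop(0, 3, ctx.interp_iter)
print(kernel.ops)  # [('for', 0, 3, [('h', SymbolicIndex(name='i'))])]
```

The closure keeps the interpreter state (`env`) alive across the builder's callback, which is what lets a bound method stand in for the plain function `for_loop` expects.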
Description of the Change
- Added a custom `cudaq_for_p` primitive in `catalyst.cuda.catalyst.primitives`.
- Implemented a function that takes an `InterpreterContext` for a `for_p` equation and calls `cudaq.for_loop` with appropriate arguments.
- Added test cases to check the CUDAQ backend result against the lightning backend's output.
Benefits:
- Support `@for_loop(start, end, step)` for the CUDAQ backend.
Possible Drawbacks:
- Only `step=1` is supported at the moment.
- Loop-carried variables are not supported.
- Dynamic wires are not supported, because JAX tries to add a QuakeValue to either a constant or a JAX variable. For example, the following program fails:
```python
import pennylane as qml
from catalyst import for_loop
from catalyst.cuda import SoftwareQQPP

@qml.qnode(SoftwareQQPP(wires=6))
def circuit(n: int):
    @for_loop(0, n - 1, 1)
    def loop_fn(i):
        qml.CNOT(wires=[i, i + 1])  # <-- notice i + 1

    loop_fn()
    return qml.state()
```
Related GitHub Issues: #524
Hi @zzzDavid, thanks for opening this PR 🎉
Let us know once it's ready for review and we'll take a look!
Codecov Report
All modified and coverable lines are covered by tests ✅
Project coverage is 99.55%. Comparing base (47fd3fc) to head (3211db8). Report is 19 commits behind head on main.
❗ Current head 3211db8 differs from pull request most recent head db57879. Consider uploading reports for the commit db57879 to get more accurate results.
Additional details and impacted files:
```diff
@@            Coverage Diff            @@
##             main     #541   +/-   ##
=======================================
  Coverage   99.55%   99.55%
=======================================
  Files          52       52
  Lines        8457     8484   +27
  Branches      559      559
=======================================
+ Hits         8419     8446   +27
  Misses         20       20
  Partials       18       18
```
Hello @dime10, I have a version that passes all tests now. Please take a look at my changes. Thanks!!
Hi @erick-xanadu, with a working `cudaq` version I was able to test my implementation further, and I discovered some issues that I'd like to discuss. A `cudaq.for_loop` API call invokes its loop body function only once, to build the MLIR IR for the body of a `WhileOp` (see `cuda-quantum/python/cudaq/kernel/kernel_builder.py`):
```python
with InsertionPoint(bodyBlock):
    tmpIp = self.insertPoint
    self.insertPoint = InsertionPoint(bodyBlock)
    function(self.__createQuakeValue(bodyBlock.arguments[0]))
    self.insertPoint = tmpIp
    cc.ContinueOp(bodyBlock.arguments)
```
For an interpreter approach, however, I would need the loop body function to be called once per iteration. I could implement this with a Python for loop that calls the loop body function, but I'm not sure that's a good idea because it would no longer use `cudaq.for_loop`. Do you have any suggestions?
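To make the trade-off concrete, here is a hedged sketch with a hypothetical recording builder (not the real `cudaq` kernel builder; `RecordingKernel` and `make_body` are illustrative names). A tracing `for_loop` calls the body once with a symbolic index and emits one loop, whereas a Python-level loop calls the body once per iteration and emits an unrolled sequence. The unrolled variant assumes the loop bounds are concrete Python integers at interpretation time.

```python
class RecordingKernel:
    """Hypothetical builder that records gates as (name, qubit) tuples."""

    def __init__(self):
        self.ops = []

    def h(self, qubit):
        self.ops.append(("h", qubit))

    def for_loop(self, start, stop, function):
        # Tracing behavior: one call with a symbolic index, one loop op emitted.
        outer, self.ops = self.ops, []
        function("%idx")  # a string stands in for a symbolic QuakeValue
        body, self.ops = self.ops, outer
        self.ops.append(("for", start, stop, body))


def make_body(kernel, calls):
    """Build a loop-body callback that logs each call and applies H to the index."""

    def body(index):
        calls.append(index)
        kernel.h(index)

    return body


# Option 1: let the builder trace the body once.
traced, traced_calls = RecordingKernel(), []
traced.for_loop(0, 3, make_body(traced, traced_calls))

# Option 2: unroll in Python; this needs concrete integer bounds.
unrolled, unrolled_calls = RecordingKernel(), []
body = make_body(unrolled, unrolled_calls)
for i in range(0, 3):
    body(i)

print(traced.ops)    # [('for', 0, 3, [('h', '%idx')])]
print(unrolled.ops)  # [('h', 0), ('h', 1), ('h', 2)]
```

Unrolling gives the interpreter a concrete index on every call, but the emitted program grows with the trip count; the traced loop stays compact but exposes only a single symbolic index to the body.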
Hi Niansong,
I made the following changes to your branch:
```diff
diff --git a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
index bffdc8e9..3d6d8019 100644
--- a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
+++ b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
@@ -723,10 +723,12 @@ def change_for(ctx, eqn):
 
         def interp_iter(self, iteration):  # pylint: disable=unused-argument
             """Called by cudaq.for_loop, interpret the loop body."""
+            new_elems = invals[3:]
+            new_elems[0] = iteration
+            _map(ctx.write, loop_body.invars, new_elems)
             res = interpret_impl(self.ctx, self.loop_body)
             self.outvars = res
 
-    _map(ctx.write, loop_body.invars, invals[3:])
     body_ctx = LoopContext(ctx, loop_body)
     cudaq_for(ctx.kernel, start, end, body_ctx.interp_iter)
     _map(ctx.write, eqn.outvars, body_ctx.outvars)
```
Something I didn't catch during review: since CUDA Quantum creates a new QuakeValue to represent the current index, that value should be fed into the loop body. It replaces the value of `loop_body.invars` at position 0.
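The substitution can be illustrated in isolation. This sketch assumes, as the diff suggests, that the loop-body inputs start at `invals[3]` and that the body's first invar is the loop index; `bind_loop_body_args`, the variable names, and the plain-dict environment are hypothetical simplifications of the interpreter context. The `list(...)` copy is a defensive choice in case `invals` is a tuple; whether that case arises in Catalyst's interpreter is not confirmed here.

```python
def bind_loop_body_args(env, body_invars, invals, iteration):
    """Write loop-body inputs into env, replacing slot 0 with the current index.

    Hypothetical helper: assumes invals[3:] holds the body's inputs and that the
    body's first invar is the loop index.
    """
    new_elems = list(invals[3:])  # copy, so neither a tuple nor the caller's list is mutated
    new_elems[0] = iteration
    if len(new_elems) != len(body_invars):
        raise ValueError("loop body arity mismatch")
    for var, val in zip(body_invars, new_elems):
        env[var] = val
    return env


# Hypothetical layout: start, stop, step, then the body inputs (index placeholder, register).
invals = (0, 5, 1, 0, "qreg")
env = bind_loop_body_args({}, ["i", "q"], invals, iteration="%idx")
print(env)     # {'i': '%idx', 'q': 'qreg'}
print(invals)  # (0, 5, 1, 0, 'qreg')
```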
With the index fed in this way, programs that use the loop variable directly (without arithmetic on it) now succeed. For example:
```python
import catalyst
import pennylane as qml
from catalyst import for_loop, qjit
from catalyst.cuda import SoftwareQQPP
from numpy.testing import assert_allclose

@qml.qnode(qml.device("lightning.qubit", wires=6))
def circuit_lightning(n: int):
    @for_loop(0, n - 1, 1)
    def loop_fn(i):
        qml.Hadamard(wires=[i])

    loop_fn()
    return qml.state()

@qml.qnode(SoftwareQQPP(wires=6))
def circuit(n: int):
    @for_loop(0, n - 1, 1)
    def loop_fn(i):
        qml.Hadamard(wires=[i])

    loop_fn()
    return qml.state()

cuda_compiled = catalyst.cuda.qjit(circuit)
catalyst_compiled = qjit(circuit_lightning)
expected = catalyst_compiled(4)
observed = cuda_compiled(4)
assert_allclose(expected, observed)
```
However, the test for dynamic wires fails because JAX tries to add a QuakeValue to either a constant or a JAX variable. For example, running:
```python
import pennylane as qml
from catalyst import for_loop
from catalyst.cuda import SoftwareQQPP

@qml.qnode(SoftwareQQPP(wires=6))
def circuit(n: int):
    @for_loop(0, n - 1, 1)
    def loop_fn(i):
        qml.CNOT(wires=[i, i + 1])  # <-- notice i + 1

    loop_fn()
    return qml.state()
```
produces the following error:
```
E TypeError: Cannot interpret value of type <class 'cudaq._pycudaq.QuakeValue'> as an abstract array; it does not have a dtype attribute
```
I wonder whether there is a way to trick JAX into performing at least some of the computation when the other operand is a concrete value.
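One possible direction, sketched under loud assumptions: when the interpreter reaches an `add` equation and either operand is a builder value rather than a JAX-compatible array, it could skip the JAX primitive and use the builder value's own `+` operator (the CUDAQ translation earlier in this PR already relies on `index + 1` working on a QuakeValue). `FakeQuakeValue`, `interp_add`, and the `jax_add` fallback are hypothetical; the fallback is plain Python addition so the sketch runs without JAX.

```python
class FakeQuakeValue:
    """Hypothetical stand-in for a builder value that supports symbolic arithmetic."""

    def __init__(self, expr):
        self.expr = expr

    def __add__(self, other):
        other_expr = other.expr if isinstance(other, FakeQuakeValue) else repr(other)
        return FakeQuakeValue(f"({self.expr} + {other_expr})")

    def __radd__(self, other):
        return FakeQuakeValue(f"({other!r} + {self.expr})")


def jax_add(x, y):
    # Placeholder for evaluating the JAX primitive on concrete operands.
    return x + y


def interp_add(x, y):
    """Route addition to the builder value when either operand is symbolic."""
    if isinstance(x, FakeQuakeValue) or isinstance(y, FakeQuakeValue):
        return x + y  # dispatches to FakeQuakeValue.__add__ or __radd__
    return jax_add(x, y)


i = FakeQuakeValue("%idx")
print(interp_add(i, 1).expr)  # (%idx + 1)
print(interp_add(1, i).expr)  # (1 + %idx)
print(interp_add(2, 3))       # 5
```

Whether every primitive the interpreter meets has a matching QuakeValue operator is not checked here; unsupported ones would still need a clear error.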
Let's just make the changes I suggested for now, or, if you prefer, implement the interpreted version (which would allow more programs). I'm happy with either for the time being. :)
Thanks @zzzDavid!