[Frontend][CUDAQ] Support `for_p` for CUDA Quantum Interpreter (#524)

Open zzzDavid opened this issue 1 year ago • 5 comments

Context

Catalyst supports the CUDA Quantum (CUDAQ) interpreter as a simulator backend (SoftwareQQPP) through a custom Jaxpr interpreter that translates the input program into CUDA Quantum Python API calls. This PR adds support for the for_p for-loop primitive on the CUDAQ backend.

Goal

The goal is to translate Catalyst programs with for loops into CUDAQ Python API calls. For example, this is an input program:

@qjit()
@qml.qnode(qml.device(backend, wires=6))
def circuit(n: int):
    qml.Hadamard(wires=0)

    @for_loop(0, n - 1, 1)
    def loop_fn(i):
        qml.CNOT(wires=[i, i + 1])

    loop_fn()
    return qml.state()

It should be translated to the following CUDAQ API calls:

def circuit(n: int):

  kernel = cudaq.make_kernel()
  qreg = kernel.qalloc(6)
  qubit0 = qreg[0]
  kernel.h(qubit0)
  def loop(index):
    qubit_i = qreg[index]
    qubit_i_plus_1 = qreg[index + 1]
    kernel.cx(qubit_i, qubit_i_plus_1)
  kernel.for_loop(start=0, stop=n-1, function=loop)
  return cudaq.get_state(kernel)

Approach

Construct a function that can be passed to cudaq.for_loop and that interprets the loop body at a given iteration.

Description of the Change

  • Added a custom cudaq_for_p primitive in catalyst.cuda.catalyst.primitives.
  • Implemented a function that takes an InterpreterContext and a for_p equation and calls cudaq.for_loop with the appropriate arguments.
  • Added test cases to check the CUDAQ backend result against the lightning backend's output.

Benefits:

  • Support @for_loop(start, end, step) for CUDAQ backend.

Possible Drawbacks:

  • Only supports step=1 at the moment.
  • Does not support loop-carried variables.
  • Does not support dynamic wires, because JAX tries to add a QuakeValue to either a constant or a JAX variable. E.g., the following program fails:
        from catalyst.cuda import SoftwareQQPP

        @qml.qnode(SoftwareQQPP(wires=6))
        def circuit(n: int):
            @for_loop(0, n - 1, 1)
            def loop_fn(i):
                qml.CNOT(wires=[i, i + 1]) # <-- notice i + 1

            loop_fn()
            return qml.state()

Related GitHub Issues: #524

zzzDavid avatar Feb 23 '24 19:02 zzzDavid

Hi @zzzDavid, thanks for opening this PR 🎉

Let us know once it's ready for review and we'll take a look!

dime10 avatar Feb 23 '24 19:02 dime10

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 99.55%. Comparing base (47fd3fc) to head (3211db8). Report is 19 commits behind head on main.

:exclamation: Current head 3211db8 differs from pull request most recent head db57879. Consider uploading reports for the commit db57879 to get more accurate results

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #541   +/-   ##
=======================================
  Coverage   99.55%   99.55%           
=======================================
  Files          52       52           
  Lines        8457     8484   +27     
  Branches      559      559           
=======================================
+ Hits         8419     8446   +27     
  Misses         20       20           
  Partials       18       18           


codecov[bot] avatar Feb 23 '24 20:02 codecov[bot]

Hello @dime10, I have a version that passes all tests now, please take a look at my changes. Thanks!!

zzzDavid avatar Feb 23 '24 21:02 zzzDavid

Hi @erick-xanadu, with a working cudaq version I was able to test my implementation further, and I discovered a problem that I'd like to discuss. A cudaq.for_loop API call invokes its loop body function only once, in cuda-quantum/python/cudaq/kernel/kernel_builder.py, to build the MLIR for a WhileOp body:

          with InsertionPoint(bodyBlock):
              tmpIp = self.insertPoint
              self.insertPoint = InsertionPoint(bodyBlock)
              function(self.__createQuakeValue(bodyBlock.arguments[0]))
              self.insertPoint = tmpIp
              cc.ContinueOp(bodyBlock.arguments)

For an interpreter approach, however, I would need the loop body function to be called once per iteration so that each iteration is evaluated. I could implement this with a Python for loop that calls the loop body function, but I'm not sure that's a good idea because it would no longer use cudaq.for_loop. Do you have any suggestions?
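
To make the difference concrete, here is a minimal sketch of the two behaviours. ToyKernel and count_calls are hypothetical stand-ins written for this illustration; they are not the CUDA Quantum API. The toy for_loop calls the body once with a symbolic index and records a single loop instruction, whereas the Python for loop calls the body once per iteration and records unrolled instructions.

```python
# Toy stand-in for a kernel builder; NOT the CUDA Quantum API.
class ToyKernel:
    def __init__(self):
        self.ops = []  # recorded instructions

    def h(self, target):
        self.ops.append(("h", target))

    def for_loop(self, start, stop, function):
        # Traced style: the body runs once with a symbolic index, and the
        # instructions it records are wrapped in a single loop instruction.
        outer, self.ops = self.ops, []
        function("i")  # symbolic placeholder, not a concrete integer
        body, self.ops = self.ops, outer
        self.ops.append(("for", start, stop, body))


def count_calls(fn):
    """Wrap fn so the number of invocations can be inspected."""
    def wrapper(*args):
        wrapper.calls += 1
        return fn(*args)
    wrapper.calls = 0
    return wrapper


# Traced: one call to the body, one loop instruction.
traced = ToyKernel()
traced_body = count_calls(lambda i: traced.h(i))
traced.for_loop(0, 3, traced_body)

# Unrolled: the interpreter calls the body once per iteration.
unrolled = ToyKernel()
unrolled_body = count_calls(lambda i: unrolled.h(i))
for i in range(0, 3):
    unrolled_body(i)

print(traced_body.calls, traced.ops)
print(unrolled_body.calls, unrolled.ops)
```

In this toy model the unrolled variant only works when the loop bounds are concrete Python integers at interpretation time, while the traced variant keeps the loop as one instruction regardless of the bounds.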

zzzDavid avatar Mar 01 '24 15:03 zzzDavid

Hi Niansong,

I made the following changes to your branch:

diff --git a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
index bffdc8e9..3d6d8019 100644
--- a/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
+++ b/frontend/catalyst/cuda/catalyst_to_cuda_interpreter.py
@@ -723,10 +723,12 @@ def change_for(ctx, eqn):
 
         def interp_iter(self, iteration):  # pylint: disable=unused-argument
             """Called by cudaq.for_loop, interpret the loop body."""
+            new_elems = invals[3:]
+            new_elems[0] = iteration
+            _map(ctx.write, loop_body.invars, new_elems)
             res = interpret_impl(self.ctx, self.loop_body)
             self.outvars = res
 
-    _map(ctx.write, loop_body.invars, invals[3:])
     body_ctx = LoopContext(ctx, loop_body)
     cudaq_for(ctx.kernel, start, end, body_ctx.interp_iter)
     _map(ctx.write, eqn.outvars, body_ctx.outvars)

Something I didn't catch during review: since CUDA Quantum creates a new QuakeValue to represent the current index, that value should be fed into the loop body. It replaces the value of loop_body.invars at position 0.

This allows for programs that do not modify the loop variable to succeed. E.g.,

        @qml.qnode(qml.device("lightning.qubit", wires=6))
        def circuit_lightning(n: int):
            @for_loop(0, n - 1, 1)
            def loop_fn(i):
                qml.Hadamard(wires=[i])

            loop_fn()
            return qml.state()

        from catalyst.cuda import SoftwareQQPP

        @qml.qnode(SoftwareQQPP(wires=6))
        def circuit(n: int):
            @for_loop(0, n - 1, 1)
            def loop_fn(i):
                qml.Hadamard(wires=[i])

            loop_fn()
            return qml.state()

        cuda_compiled = catalyst.cuda.qjit(circuit)
        catalyst_compiled = qjit(circuit_lightning)
        expected = catalyst_compiled(4)
        observed = cuda_compiled(4)
        assert_allclose(expected, observed)
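
The index substitution from the diff can be sketched in isolation as follows. ToyContext, toy_for_loop, and the layout of invals are assumptions made for this illustration (the layout is inferred from the invals[3:] slice in the diff); none of them are the actual Catalyst or CUDA Quantum objects. The point is only that the callback binds the loop body inputs itself, using the index value handed to it by the builder.

```python
# Hypothetical stand-ins for the interpreter state; names are illustrative.
class ToyContext:
    def __init__(self):
        self.env = {}

    def write(self, var, val):
        self.env[var] = val

    def read(self, var):
        return self.env[var]


def toy_for_loop(function):
    # Mimics a builder that calls the body once with a fresh symbolic index.
    function("%index")


ctx = ToyContext()
body_invars = ["i", "x"]       # loop body inputs: index first, then the rest
# Assumed layout: start, stop, step, then the loop body inputs.
invals = [0, 3, 1, 0, "qreg"]
seen = []


def interp_iter(iteration):
    new_elems = invals[3:]     # slicing copies, so invals is left intact
    new_elems[0] = iteration   # replace the index with the builder's value
    for var, val in zip(body_invars, new_elems):
        ctx.write(var, val)
    seen.append((ctx.read("i"), ctx.read("x")))


toy_for_loop(interp_iter)
print(seen)
```

Writing the inputs inside the callback, rather than once before the loop is built, is what lets the body see the builder's index instead of the initial value.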

However, the test for dynamic wires fails due to JAX trying to add a QuakeValue to either a constant or a JAX variable. E.g., running:

        from catalyst.cuda import SoftwareQQPP

        @qml.qnode(SoftwareQQPP(wires=6))
        def circuit(n: int):
            @for_loop(0, n - 1, 1)
            def loop_fn(i):
                qml.CNOT(wires=[i, i + 1]) # <-- notice i + 1

            loop_fn()
            return qml.state()

produces the following error.

E         TypeError: Cannot interpret value of type <class 'cudaq._pycudaq.QuakeValue'> as an abstract array; it does not have a dtype attribute

I do wonder whether there is a way to trick JAX into performing at least some of the computation when the other operand is a concrete value.

Let's just make the changes I suggest for now, or, if you prefer, do the interpreted version (which would allow for more programs). I am happy with either for the time being. :)

Thanks @zzzDavid!

erick-xanadu avatar Mar 01 '24 23:03 erick-xanadu