Refactor CUDA qjit decorator to inherit from catalyst.QJIT

Open josh146 opened this issue 1 year ago • 3 comments

Context: With the change in #531, we can now modify the CUDA @qjit decorator to inherit from the main catalyst.qjit decorator. This lets the CUDA Quantum QJIT support the functionality provided by the main QJIT class while reducing code duplication.

Description of the Change:

  • QJIT_CUDA now inherits from catalyst.QJIT.
  • catalyst.cuda.cudaqjit is rewritten to match the catalyst.qjit wrapper function implementation.
  • The CUDA interpret function is now redundant and can be removed.

Benefits:

  • Support for caching the JAXPR capture: capture happens 'just in time' (or ahead of time) and is not repeated on every execution of the kernel.
  • autograph support is now enabled.
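The caching benefit can be illustrated with a minimal sketch. This is not Catalyst's implementation: `BaseJIT`, `CudaJIT`, `capture`, and `run` are hypothetical names chosen for illustration, and the "capture" step is a placeholder for JAXPR tracing. The sketch assumes the subclass only needs to override execution, so the capture-once logic lives in the base class.

```python
from typing import Any, Callable, Optional


class BaseJIT:
    """Hypothetical stand-in for a QJIT-style base class (not Catalyst's API)."""

    def __init__(self, fn: Callable, aot: bool = False):
        self.fn = fn
        self.captured: Optional[Any] = None
        self.capture_count = 0
        if aot:
            # Ahead of time: capture once when the function is decorated.
            self.captured = self.capture()

    def capture(self) -> Any:
        """Placeholder for the expensive tracing step."""
        self.capture_count += 1
        return ("trace", self.fn.__name__)

    def run(self, args: tuple) -> Any:
        return self.fn(*args)

    def __call__(self, *args):
        if self.captured is None:
            # Just in time: capture on the first call, then reuse the result.
            self.captured = self.capture()
        return self.run(args)


class CudaJIT(BaseJIT):
    """Hypothetical subclass: overrides execution only, inherits caching."""

    def run(self, args: tuple) -> Any:
        return ("cuda", self.captured, self.fn(*args))


def double(x):
    return 2 * x


jit_fn = CudaJIT(double)
jit_fn(1)
jit_fn(2)
# capture ran once even though the function was called twice
print(jit_fn.capture_count)
```

Because `__call__` lives in the base class, the subclass gets the cached-capture behaviour without duplicating it, which is the code-duplication argument made in the description.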

Possible Drawbacks:

  • autograph support won't work end-to-end until CUDA Q supports for loops and if statements.
  • static_argnums support is not enabled. I'm not sure it makes sense here, given that the CUDA qjit only does just-in-time capture, not just-in-time compilation.
  • Similarly, other qjit keyword arguments such as pipeline, target, and verbose likely do not make sense or cannot be supported, so they are hardcoded to be disabled.
  • extract_backend_info should not be run automatically by the CUDA compiler, as it is Catalyst-specific. We need to make this API a bit nicer for third-party compilers.

Related GitHub Issues: n/a

josh146, Mar 15 '24 02:03

The only test not currently working is test_expval_2:

@cudaq_qjit
@qml.qnode(qml.device("softwareq.qpp", wires=2))
def circuit():
    qml.RY(jnp.pi / 4, wires=[1])
    return qml.expval(qml.PauliZ(1) + qml.PauliX(1))
Traceback:
>>> circuit()
/usr/local/lib/python3.10/dist-packages/catalyst/jit.py in __call__(self, *args, **kwargs)
    118             args = promote_arguments(self.c_sig, dynamic_args)
    119 
--> 120         return self.run(args, kwargs)
    121 
    122     def aot_compile(self):

<ipython-input-15-41f6ffd49162> in run(self, args, kwargs)
     52         # # use catalyst.cuda.cuda_qjit to interpret the JAXPR and execute via CUDA Quantum
     53         ctx = InterpreterContext(self.jaxpr.jaxpr, self.jaxpr.literals, *args)
---> 54         results = interpret_impl(ctx, self.jaxpr.jaxpr)
     55 
     56         return tree_unflatten(self.out_treedef, results)

/usr/local/lib/python3.10/dist-packages/catalyst/cuda/catalyst_to_cuda_interpreter.py in interpret_impl(ctx, jaxpr)
    785         # This is similar to direct-call threading
    786         # https://www.cs.toronto.edu/~matz/dissertation/matzDissertation-latex2html/node6.html
--> 787         INST_IMPL.get(eqn.primitive, default_impl)(ctx, eqn)
    788 
    789     retvals = _map(ctx.read, jaxpr.outvars)

/usr/local/lib/python3.10/dist-packages/catalyst/cuda/catalyst_to_cuda_interpreter.py in change_hamiltonian(ctx, eqn)
    631     assert eqn.primitive == hamiltonian_p
    632 
--> 633     invals = _map(ctx.read, eqn.invars)
    634     coeffs = invals[0]
    635     terms = invals[1:]

/usr/local/lib/python3.10/dist-packages/catalyst/cuda/catalyst_to_cuda_interpreter.py in _map(f, *collections)
    146 def _map(f, *collections):
    147     """Eager implementation of map."""
--> 148     return list(map(f, *collections))
    149 
    150 

/usr/local/lib/python3.10/dist-packages/catalyst/cuda/catalyst_to_cuda_interpreter.py in read(self, var)
    210         if self.variable_map.get(var):
    211             var = self.variable_map[var]
--> 212         return self.env[var]
    213 
    214     def write(self, var, val):

KeyError: a
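
The traceback ends in a dictionary lookup for a variable that is missing from the interpreter's environment. As a minimal sketch of that failure mode only (it does not diagnose why `a` was never written in the real interpreter), the read/write pattern visible in the traceback can be mimicked as below. The `Env` class and the string variable names are hypothetical; only the shape of `read` follows the lines shown above.

```python
class Env:
    """Hypothetical interpreter environment mapping variables to values."""

    def __init__(self):
        self.env = {}
        self.variable_map = {}

    def write(self, var, val):
        self.env[var] = val

    def read(self, var):
        # Follow an alias if one is registered, then look up the value.
        if self.variable_map.get(var):
            var = self.variable_map[var]
        return self.env[var]


ctx = Env()
ctx.write("b", 0.5)

try:
    # "a" was never written, so the lookup fails like the traceback above.
    ctx.read("a")
except KeyError as err:
    missing = err.args[0]
    print(f"KeyError: {missing}")
```

In an interpreter built this way, such an error means an equation consumed an input that no earlier equation (or argument binding) had produced in the environment.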

josh146, Mar 15 '24 03:03

The errors I am seeing on CI are different 🤔

FAILED frontend/test/pytest/test_cuda_integration.py::TestCudaQ::test_qjit_cuda_remove_host_context - catalyst.utils.exceptions.CompileError: Cannot translate tapes with context.
FAILED frontend/test/pytest/test_cuda_integration.py::TestCudaQ::test_expval_2 - KeyError: a
FAILED frontend/test/pytest/test_cuda_integration.py::TestCudaQ::test_jit_capture - AssertionError: Expected 'capture' to not have been called. Called 1 times.
Calls: [call((Array([0.1, 0.2], dtype=float64),))].
FAILED frontend/test/pytest/test_cuda_integration.py::TestCudaQ::test_aot_capture - AssertionError: Expected 'capture' to have been called.
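
The two capture failures are assertions about when a mocked `capture` method gets called. A minimal sketch of that style of check, using `unittest.mock` from the standard library, is shown below. `LazyJIT` is a hypothetical class written for this example and is not the code under test in `test_cuda_integration.py`.

```python
from unittest.mock import patch


class LazyJIT:
    """Hypothetical decorator class that captures on first call only."""

    def __init__(self, fn):
        self.fn = fn
        self.trace = None

    def capture(self, args):
        return ("trace", args)

    def __call__(self, *args):
        if self.trace is None:
            self.trace = self.capture(args)
        return self.fn(*args)


def check_jit_capture():
    # Replace LazyJIT.capture with a mock so calls to it can be counted.
    with patch.object(LazyJIT, "capture", return_value="trace") as capture:
        jitted = LazyJIT(lambda x: x + 1)
        capture.assert_not_called()    # nothing captured at decoration time
        result = jitted(1)
        capture.assert_called_once()   # captured on the first call
        jitted(2)
        capture.assert_called_once()   # cached, so not captured again
    return result


result = check_jit_capture()
```

An ahead-of-time variant would assert the opposite at decoration time, which is why changing where capture happens can flip both tests at once.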

erick-xanadu, Mar 15 '24 13:03

> The errors I am seeing on CI are different 🤔

I've just pushed a fix; the only failure should now be test_expval_2.

josh146, Mar 15 '24 14:03

@josh146 any chance this was close to landing? Sounded like a decent improvement :D

dime10, Jul 12 '24 14:07

@dime10 this was almost ready to go. There was just a single failing test that was beyond my expertise, and I couldn't solve it.

josh146, Jul 12 '24 15:07

made obsolete by #926

erick-xanadu, Jul 12 '24 18:07