catalyst
Refactor CUDA qjit decorator to inherit from catalyst.QJIT
Context: With the change in #531, we can now modify the CUDA @qjit decorator to inherit from the main catalyst.qjit decorator. This enables the functionality provided by the main QJIT class to be supported by the CUDA quantum QJIT, while reducing code duplication.
Description of the Change:
- Modifies QJIT_CUDA to inherit from catalyst.QJIT.
- catalyst.cuda.cudaqjit is rewritten to match the catalyst.qjit wrapper function implementation.
- The CUDA interpret function is now redundant and can be removed.
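The shape of this refactor can be sketched in plain Python. The names below (BaseJIT, CudaJIT, cuda_jit) are hypothetical stand-ins, not Catalyst's classes; the sketch only illustrates how a subclass might reuse the parent's call path while overriding the execution step, and how a wrapper function can support both bare and parameterised decorator use.

```python
class BaseJIT:
    """Hypothetical stand-in for a generic JIT wrapper class."""

    def __init__(self, fn, *, autograph=False):
        self.fn = fn
        self.autograph = autograph
        self.captured = None

    def capture(self, args):
        # Placeholder for program capture: remember the function and arity.
        return {"fn": self.fn, "n_args": len(args)}

    def run(self, args):
        # Default execution path of the parent class.
        return ("base", self.captured["fn"](*args))

    def __call__(self, *args):
        # Shared call path: capture once, then execute.
        if self.captured is None:
            self.captured = self.capture(args)
        return self.run(args)


class CudaJIT(BaseJIT):
    """Hypothetical subclass: inherits __call__ and capture, overrides run."""

    def run(self, args):
        # A different backend would execute the captured program here.
        return ("cuda", self.captured["fn"](*args))


def cuda_jit(fn=None, *, autograph=False):
    """Wrapper usable as @cuda_jit or @cuda_jit(autograph=True)."""
    if fn is not None:
        return CudaJIT(fn, autograph=autograph)

    def wrap(f):
        return CudaJIT(f, autograph=autograph)

    return wrap


@cuda_jit
def double(x):
    return 2 * x


@cuda_jit(autograph=True)
def triple(x):
    return 3 * x
```

Because only run is overridden, anything added to the parent's call path is picked up by the subclass without being duplicated.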
Benefits:
- Support for caching the JAXPR capture: capture happens 'just in time' (or ahead of time), so it is not repeated on every execution of the kernel.
- autograph support is now enabled
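The caching benefit can be illustrated with a small self-contained sketch. CachingTracer and capture_count are hypothetical and stand in for the real capture machinery. The assumption modelled here is that capture runs once on the first call (just in time), or at decoration time when every argument carries a type hint (ahead of time), and is reused afterwards.

```python
import inspect


class CachingTracer:
    """Hypothetical wrapper that captures a function once and reuses the result."""

    def __init__(self, fn):
        self.fn = fn
        self.program = None
        self.capture_count = 0
        params = list(inspect.signature(fn).parameters.values())
        # Assumption: fully annotated arguments permit ahead-of-time capture.
        if params and all(p.annotation is not inspect.Parameter.empty for p in params):
            self._capture()

    def _capture(self):
        self.capture_count += 1
        # Placeholder: a real implementation would store a traced program here.
        self.program = self.fn

    def __call__(self, *args):
        # Just-in-time path: capture only if nothing has been cached yet.
        if self.program is None:
            self._capture()
        return self.program(*args)


def add_one(x):
    return x + 1


def scale(x: float):
    return 2.0 * x


jit_fn = CachingTracer(add_one)   # no hints: captured on first call
aot_fn = CachingTracer(scale)     # hints present: captured immediately
```

Calling jit_fn repeatedly leaves capture_count at 1, which is the behaviour the bullet above describes.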
Possible Drawbacks:
- autograph support won't work end-to-end until CUDA Q support for for loops and if statements is added.
- static_argnums support is not enabled. I'm not sure if this makes sense here, given the CUDA qjit is simply doing just-in-time capture, and not just-in-time compilation.
- Similarly, other qjit keyword arguments such as pipeline, target, verbose, etc. likely also do not make sense/cannot be supported, and are hardcoded to be disabled.
- The extract_backend_info should not be automatically run by the cuda compiler as it is catalyst-specific. We need to make this API a bit nicer for third-party compilers.
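One way to hardcode unsupported options is for the subclass constructor to accept only what it supports and pass fixed values to the parent. The sketch below uses hypothetical classes (GenericJIT, CaptureOnlyJIT) and placeholder defaults; it borrows the option names from the list above but does not reflect Catalyst's actual constructor signature.

```python
class GenericJIT:
    """Hypothetical parent class that accepts the full set of options."""

    def __init__(self, fn, *, autograph=False, static_argnums=None,
                 target="default", verbose=False):
        self.fn = fn
        self.autograph = autograph
        self.static_argnums = static_argnums
        self.target = target
        self.verbose = verbose


class CaptureOnlyJIT(GenericJIT):
    """Hypothetical subclass that exposes autograph only."""

    def __init__(self, fn, *, autograph=False):
        # Options that make no sense for capture-only use are fixed here,
        # so callers cannot enable them by accident.
        super().__init__(
            fn,
            autograph=autograph,
            static_argnums=None,
            target=None,
            verbose=False,
        )


def identity(x):
    return x


wrapped = CaptureOnlyJIT(identity, autograph=True)

try:
    CaptureOnlyJIT(identity, verbose=True)
    rejected = False
except TypeError:
    # Unsupported keywords fail at construction time.
    rejected = True
```

Rejecting unsupported keywords with a TypeError, rather than silently ignoring them, makes the limitation visible to the caller.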
Related GitHub Issues: n/a
The only test not currently working is test_expval_2:
@cudaq_qjit
@qml.qnode(qml.device("softwareq.qpp", wires=2))
def circuit():
    qml.RY(jnp.pi / 4, wires=[1])
    return qml.expval(qml.PauliZ(1) + qml.PauliX(1))
See here for the traceback
>>> circuit()
/usr/local/lib/python3.10/dist-packages/catalyst/jit.py in __call__(self, *args, **kwargs)
    118 args = promote_arguments(self.c_sig, dynamic_args)
    119
--> 120 return self.run(args, kwargs)
    121
    122 def aot_compile(self):

<ipython-input-15-41f6ffd49162> in run(self, args, kwargs)
     52 # # use catalyst.cuda.cuda_qjit to interpret the JAXPR and execute via CUDA Quantum
     53 ctx = InterpreterContext(self.jaxpr.jaxpr, self.jaxpr.literals, *args)
---> 54 results = interpret_impl(ctx, self.jaxpr.jaxpr)
     55
     56 return tree_unflatten(self.out_treedef, results)

/usr/local/lib/python3.10/dist-packages/catalyst/cuda/catalyst_to_cuda_interpreter.py in interpret_impl(ctx, jaxpr)
    785 # This is similar to direct-call threading
    786 # https://www.cs.toronto.edu/~matz/dissertation/matzDissertation-latex2html/node6.html
--> 787 INST_IMPL.get(eqn.primitive, default_impl)(ctx, eqn)
    788
    789 retvals = _map(ctx.read, jaxpr.outvars)

/usr/local/lib/python3.10/dist-packages/catalyst/cuda/catalyst_to_cuda_interpreter.py in change_hamiltonian(ctx, eqn)
    631 assert eqn.primitive == hamiltonian_p
    632
--> 633 invals = _map(ctx.read, eqn.invars)
    634 coeffs = invals[0]
    635 terms = invals[1:]

/usr/local/lib/python3.10/dist-packages/catalyst/cuda/catalyst_to_cuda_interpreter.py in _map(f, *collections)
    146 def _map(f, *collections):
    147 """Eager implementation of map."""
--> 148 return list(map(f, *collections))
    149
    150

/usr/local/lib/python3.10/dist-packages/catalyst/cuda/catalyst_to_cuda_interpreter.py in read(self, var)
    210 if self.variable_map.get(var):
    211 var = self.variable_map[var]
--> 212 return self.env[var]
    213
    214 def write(self, var, val):

KeyError: a
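The final frame fails in an environment lookup. The following is a minimal sketch of that failure mode, assuming a dictionary-backed environment shaped like the read/write pair visible in the traceback. The Env class, the body of write, and the variable names are assumptions for illustration, and this is not a diagnosis of the actual bug: it only shows that reading a variable no earlier equation has written raises KeyError.

```python
class Env:
    """Hypothetical interpreter environment mapping variables to values."""

    def __init__(self):
        self.env = {}
        self.variable_map = {}

    def read(self, var):
        # Follow an alias if one is registered, then look up the value.
        if self.variable_map.get(var):
            var = self.variable_map[var]
        return self.env[var]

    def write(self, var, val):
        # Assumed body: the traceback shows only the signature of write.
        self.env[var] = val


env = Env()
env.write("b", 1.0)

try:
    env.read("a")  # "a" was never written, so the lookup fails
    missing = None
except KeyError as exc:
    missing = exc.args[0]
```

Under this reading, the question for test_expval_2 is which equation was expected to write the variable before change_hamiltonian reads it.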
The errors I am seeing on CI are different :thinking:
FAILED frontend/test/pytest/test_cuda_integration.py::TestCudaQ::test_qjit_cuda_remove_host_context - catalyst.utils.exceptions.CompileError: Cannot translate tapes with context.
FAILED frontend/test/pytest/test_cuda_integration.py::TestCudaQ::test_expval_2 - KeyError: a
FAILED frontend/test/pytest/test_cuda_integration.py::TestCudaQ::test_jit_capture - AssertionError: Expected 'capture' to not have been called. Called 1 times.
Calls: [call((Array([0.1, 0.2], dtype=float64),))].
FAILED frontend/test/pytest/test_cuda_integration.py::TestCudaQ::test_aot_capture - AssertionError: Expected 'capture' to have been called.
I've just pushed a fix; the only failure should now be test_expval_2.
@josh146 any chance this was close to landing? Sounded like a decent improvement :D
@dime10 this was almost ready to go, there was just a single failing test that was beyond my expertise and I couldn't solve
made obsolete by #926