catalyst
[draft] Split each tape into a different jaxpr during trace_quantum_function
Before submitting
TODO: changelog, tests, writeup of implementation
Please complete the following checklist when submitting a PR:
- [ ] All new functions and code must be clearly commented and documented.
- [ ] Ensure that code is properly formatted by running make format. The latest versions of black and clang-format-14 are used in CI/CD to check formatting.
- [ ] All new features must include a unit test. Integration and frontend tests should be added to frontend/test, Quantum dialect and MLIR tests should be added to mlir/test, and Runtime tests should be added to runtime/tests.
When all the above are checked, delete everything above the dashed line and fill in the pull request template.
Context: Split each tape into a separate jaxpr during trace_quantum_function, and return an overall jaxpr that calls these smaller per-tape ("each_tape") jaxprs.
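A minimal JAX sketch of the "one jaxpr per tape, plus one outer jaxpr that calls them" structure, assuming each tape can be represented by a plain traceable function. The names tape_0, tape_1, each_tape and overall are hypothetical and are not the internals of trace_quantum_function. Wrapping each body in jax.jit is used here only because nested jit calls are recorded in the outer jaxpr as call equations that carry the inner jaxpr in their params.

```python
import jax
from jax import lax

# Hypothetical stand-ins for the bodies of two tapes produced by a transform.
def tape_0(x):
    return lax.sin(x)

def tape_1(x):
    return lax.cos(x) * 2.0

# Wrapping each body in jax.jit keeps it as a separate sub-jaxpr: when the
# outer function is traced, each call becomes one call equation whose
# params["jaxpr"] holds that tape's jaxpr.
each_tape = [jax.jit(tape_0), jax.jit(tape_1)]

def overall(x):
    # The outer jaxpr only calls the per-tape jaxprs; results come back in
    # the order of `each_tape`, not necessarily the original return order.
    return tuple(f(x) for f in each_tape)

closed = jax.make_jaxpr(overall)(1.0)
for eqn in closed.jaxpr.eqns:
    print(eqn.primitive.name, eqn.params["jaxpr"])
```

Note that the outer return order is only as correct as the order of the tape list, since nothing in the outer jaxpr itself maps tape results back to the QNode's original return structure.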
This is a preliminary design, and some problems remain. The most significant is that the tape order does not necessarily match the return-value order of the pre-transform QNode.
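To illustrate the ordering problem, here is a hypothetical transform written in plain Python (not a PennyLane or Catalyst API). It emits one tape per measurement in an order that differs from the QNode's return order, together with a post-processing function that scatters the per-tape results back, mirroring the "tapes plus processing function" pattern that PennyLane transforms follow. The sort-by-name rule is an assumption standing in for whatever grouping a real transform applies.

```python
from typing import Callable, List, Sequence, Tuple

def hypothetical_split(return_names: Sequence[str]) -> Tuple[List[str], Callable]:
    """Hypothetical transform: one tape per measurement, sorted by name
    (a stand-in for any grouping rule), plus a post-processing function."""
    # order[k] is the position, in the original return tuple, of tape k's result.
    order = sorted(range(len(return_names)), key=lambda i: return_names[i])
    tapes = [return_names[i] for i in order]

    def postprocess(tape_results: Sequence) -> tuple:
        # Scatter each tape's result back to its original return position.
        out = [None] * len(tape_results)
        for k, i in enumerate(order):
            out[i] = tape_results[k]
        return tuple(out)

    return tapes, postprocess

tapes, postprocess = hypothetical_split(["probs", "expval"])
print(tapes)                     # tape order: ['expval', 'probs']
print(postprocess(["E", "P"]))   # original order: ('P', 'E')
```

If the outer jaxpr simply returns the per-tape results in tape order, it skips this scatter step, which is one way the mismatch can surface.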
Also, many tests in the pytest suite crash under the current design.
An alternative is to apply the transform before the tracing begins, circumventing the need to manually build jaxprs.
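A sketch of this alternative, assuming the transform has already produced the per-tape bodies and the post-processing step as ordinary traceable Python functions (all names below are hypothetical). Tracing the combined function once yields a single flat jaxpr with the post-processing included, so no sub-jaxprs are assembled by hand.

```python
import jax
from jax import lax

# Hypothetical per-tape bodies, already produced by a transform that ran
# before tracing started.
def tape_0(x):
    return lax.cos(x)

def tape_1(x):
    return lax.sin(x)

def postprocess(results):
    # Hypothetical post-processing: restore the original return order.
    return results[1], results[0]

def transformed_qnode(x):
    # Plain Python calls: tracing inlines every tape body and the
    # post-processing, so no jaxpr has to be built by hand.
    results = [tape_0(x), tape_1(x)]
    return postprocess(results)

closed = jax.make_jaxpr(transformed_qnode)(0.5)
print(closed)  # one flat jaxpr containing a cos and a sin equation
```

One possible trade-off is that the per-tape boundaries are no longer visible in the traced program, which may matter if later compilation stages need to treat each tape separately.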
Description of the Change:
Benefits:
Possible Drawbacks:
Related GitHub Issues: closes #442 [sc-67125]