catalyst
[Frontend] Ensure the order of `set_basis_state_p` and `set_state_p` is preserved.
Context: When a tape contains two state preparation operations, one of them is decomposed into gates. However, the decomposed operations are no longer guaranteed to come after the first state preparation. This is due to how tracing interacts with transforms and decomposition: the only ordering JAXPR enforces is the use-def chain. If the two state preparation operations act on disjoint subsets of wires, nothing in the use-def chain orders them, so the first state preparation can end up placed after the operations produced by decomposing the second one.
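To make the failure mode concrete, here is a minimal sketch in plain Python. It is a toy model, not Catalyst's tracer: each operation depends only on the previous operation that touched one of its wires (a stand-in for the use-def chain), and we enumerate every order consistent with those edges. The names `set_state`, `ry_w1` and `cnot_w1w2` are hypothetical placeholders for the first state preparation and the gates produced by decomposing the second one.

```python
from itertools import permutations

# Hypothetical program in trace order: (name, wires it acts on).
# "set_state" stands for the first state preparation (kept as a primitive);
# "ry_w1" and "cnot_w1w2" stand for gates produced by decomposing a second
# state preparation that acts on different wires.
OPS = [
    ("set_state", (0,)),
    ("ry_w1", (1,)),
    ("cnot_w1w2", (1, 2)),
]


def use_def_deps(ops):
    """Map each op to the ops it depends on through shared wires only."""
    last_on_wire = {}
    deps = {}
    for name, wires in ops:
        deps[name] = {last_on_wire[w] for w in wires if w in last_on_wire}
        for w in wires:
            last_on_wire[w] = name
    return deps


def valid_orders(deps):
    """Enumerate every order that respects the dependency edges."""
    names = list(deps)
    result = []
    for perm in permutations(names):
        pos = {n: i for i, n in enumerate(perm)}
        if all(pos[d] < pos[n] for n in names for d in deps[n]):
            result.append(perm)
    return result


orders = valid_orders(use_def_deps(OPS))
for order in orders:
    print(order)
```

Two of the three valid orders place the decomposed gates before `set_state`, which is the reordering described above.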
Description of the Change: We add `set_basis_state_p` and `set_state_p` to the `FORCED_ORDER_PRIMITIVES` set so that their order is preserved as they are traced.
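The effect of the change can be sketched with the same toy model. We assume, for illustration only, that marking a primitive as forced-order amounts to making every operation traced after it depend on it; this is a simplified stand-in and not necessarily how Catalyst's sorting handles `FORCED_ORDER_PRIMITIVES`. `FORCED_ORDER` and `build_deps` are hypothetical names; `graphlib.TopologicalSorter` is from the Python standard library (3.9+).

```python
from graphlib import TopologicalSorter

# Hypothetical trace-order program: the first state preparation on wire 0,
# then gates from a decomposed second state preparation on wires 1 and 2.
OPS = [
    ("set_state", (0,)),
    ("ry_w1", (1,)),
    ("cnot_w1w2", (1, 2)),
]

# Assumed stand-in for FORCED_ORDER_PRIMITIVES in this toy model.
FORCED_ORDER = {"set_state"}


def build_deps(ops, forced_order):
    """Use-def edges via shared wires, plus forced-order edges.

    Simplified model (not Catalyst's actual code): every op traced after a
    forced-order op is made to depend on it, so it cannot be hoisted above it.
    """
    last_on_wire = {}
    forced_seen = []
    deps = {}
    for name, wires in ops:
        preds = {last_on_wire[w] for w in wires if w in last_on_wire}
        preds.update(forced_seen)  # extra ordering edges
        deps[name] = preds
        for w in wires:
            last_on_wire[w] = name
        if name in forced_order:
            forced_seen.append(name)
    return deps


# TopologicalSorter takes a mapping from each node to its predecessors.
order = list(TopologicalSorter(build_deps(OPS, FORCED_ORDER)).static_order())
print(order)
```

With the extra edges, the only valid order is the trace order, so `set_state` stays ahead of the decomposed gates. In this model each forced-order op adds an edge to every later op, so the graph grows as more primitives are forced.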
Benefits: Fixes a logic error in which state preparation operations could be silently reordered.
Possible Drawbacks: Greater reliance on topological sorting. I believe it was found some time ago that topological sorting is slow, and that it becomes slower the more `FORCED_ORDER_PRIMITIVES` there are.
Related GitHub Issues:
TODO:
- [ ] MLIR pass that analyzes QNodes to make this a compile-time guarantee.