catalyst
Add support for `qml.Snapshot`
PennyLane added the qml.Snapshot operation in Release 0.22.0, which saves the internal state of devices at arbitrary points of execution. For example:
import pennylane as qml

NUM_QUBITS = 2
dev = qml.device("default.qubit", wires=NUM_QUBITS)

@qml.qnode(dev)
def circuit():
    wires = list(range(NUM_QUBITS))
    qml.Snapshot("Initial state")
    for wire in wires:
        qml.Hadamard(wires=wire)
    qml.Snapshot("After applying Hadamard gates")
    return qml.probs()

results = qml.snapshots(circuit)()
for k, result in results.items():
    print(f"{k:<30}: {result}")
Output:
Initial state : [1.+0.j 0.+0.j 0.+0.j 0.+0.j]
After applying Hadamard gates : [0.5+0.j 0.5+0.j 0.5+0.j 0.5+0.j]
execution_results : [0.25 0.25 0.25 0.25]
Adding Snapshot support to Catalyst would be a useful feature for users who want to inspect intermediate device states in compiled circuits.
Note that wrapping circuit() with the @qjit decorator currently causes qml.snapshots(circuit)() to raise a TransformError exception, since a PennyLane transform cannot be applied to a qjit-compiled function. Using @qjit also requires changing the device to "lightning.qubit" in the example above.
Installation help: Complete instructions to install Catalyst from source can be found here. Note that due to the size of the llvm-project it can take a while (~3 hrs on a personal laptop) to compile.
Some more context regarding a possible implementation:
Calling qml.snapshots currently raises a TransformError within PennyLane. To keep this issue focused, the goal is to support qml.Snapshot within Catalyst, without worrying about the qml.snapshots method for now.
A possible approach would be:
- Enable Catalyst to recognize qml.Snapshot by adding it to the list of supported runtime operations here.
- The output of the circuit can implicitly include a variable-length array of statevector snapshots, returned as part of the user function.
- Since qml.state already has a runtime implementation, it would be beneficial to dispatch qml.Snapshot efficiently by leveraging existing functionality instead of duplicating logic. A modified test would be:
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2
dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit
@qml.qnode(dev)
def circuit():
    wires = list(range(NUM_QUBITS))
    qml.Snapshot("Initial state")
    for wire in wires:
        qml.Hadamard(wires=wire)
    qml.Snapshot("After applying Hadamard gates")
    return qml.probs()

results = circuit()
print(results)
An acceptable output (format is tentative) should include the following:
[array([1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]), array([0.5+0.j, 0.5+0.j, 0.5+0.j, 0.5+0.j])], array([0.25, 0.25, 0.25, 0.25])
Breakdown of Expected Output:
1. First array: The statevector at the "Initial state" snapshot.
2. Second array: The statevector after Hadamard gates have been applied.
3. Final array: The execution results (qml.probs()), showing equal probabilities for all basis states.
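The expected values above can be checked independently of PennyLane and Catalyst with a small state-vector calculation. The following is only a plain-Python sketch for sanity-checking the numbers; `apply_hadamard` is a hypothetical helper, and the wire ordering (wire 0 as the most significant bit) is assumed to match PennyLane's convention.

```python
import math

def apply_hadamard(state, wire, num_qubits):
    """Apply a Hadamard gate to `wire` of a state vector stored as a flat list."""
    # Assumption: wire 0 is the most significant bit of the basis-state index.
    stride = 1 << (num_qubits - 1 - wire)
    s = 1 / math.sqrt(2)
    new_state = list(state)
    for i in range(len(state)):
        if (i & stride) == 0:
            a, b = state[i], state[i | stride]
            new_state[i] = s * (a + b)
            new_state[i | stride] = s * (a - b)
    return new_state

NUM_QUBITS = 2
state = [1 + 0j] + [0j] * (2**NUM_QUBITS - 1)

snapshots = [list(state)]             # "Initial state"
for wire in range(NUM_QUBITS):
    state = apply_hadamard(state, wire, NUM_QUBITS)
snapshots.append(list(state))         # "After applying Hadamard gates"

probs = [abs(amp) ** 2 for amp in state]
print(snapshots, probs)
```

Up to floating-point rounding, the two snapshots and the probabilities should agree with the three arrays listed in the breakdown.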
Hi! I was trying to address this problem, but I think there's something that I am missing.
I am trying to save the "intermediate state" before and after applying the Hadamard gates, but it seems that I cannot use the quantum tape to store that information, and "copying" the qml object and then applying qml.state() does not work. I also tried refactoring the qjit library to make it compatible with the snapshot() that is already implemented. I found that I cannot use it because, at a low level, qjit uses the jaxlib extension Array implementation while PennyLane uses numpy.ndarray. I tried to do some casting inside transform_dispatcher.py, but it is not working properly.
Probably I am missing something, but for now it seems that whatever enters the @qjit area is "ruined" and I cannot use anything else. So at this point my only idea is to write a quine and use Python regex.
Basically, when I call snapshots I get the source code at run time, extract the portion of the code in which I want to save the "intermediate state", run it in a completely new instance, and then get that partial state without ruining the main program.
If there is any qml data structure that can avoid all this, please let me know. I was reading the source code and the documentation, but it seems that anything that works for qjit is not compatible with qnode, and vice versa.
Regards
Hi @MaxBubblegum47 !
Thanks for working on this! I'll chime in with some points that you might find helpful.
So the idea of this transform is straightforward: produce a qml.state() result whenever a "snapshot bookmark" is encountered in the user script.
The big question is, should this transform happen in python, before reaching the mlir compilation stage of the catalyst stack; or, should it happen as a mlir transform pass?
Personally, I believe it should happen as a mlir transform pass, for two reasons:
1. qml.state() is what's known as a "terminal measurement" in pennylane, meaning that even if a transform somehow manages to insert one mid-circuit on the tape (I'm not sure if pennylane even allows you to do this), it will have no effect.
2. Doing this during the frontend python layer adds more preprocessing transforms before the compilation pipeline, which might not be ideal, especially if the same functionality could be achieved in the actual compilation stage. (3. Python is slower than cpp)
Since this is a brand new feature, implementing it end-to-end throughout the entire catalyst stack might be a bit too difficult in the allocated time frame. It probably makes more sense to implement this as two separate work items:
1. In the frontend, implement a mechanism to compile the user's qml.Snapshot() mid-circuit bookmarks into mid-circuit bookmark operations in the compiled mlir.
2. In mlir, add a transform that replaces these bookmark operations by StateOp, and update the return values of the function correspondingly.
I propose that this issue only implement stage 2, since stage 1 involves learning a lot of jax mechanisms, which might not be feasible in the time window allocated to you (although you are of course welcome to do so if you do find the time!).
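To make stage 2 more concrete, here is a rough Python model of what such a transform would do. It is not MLIR and not Catalyst code: the `Op` record, the "bookmark" op name, and `lower_bookmarks` are all hypothetical stand-ins, and placing the snapshot results before the original return values is just one possible convention.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Op:
    """Toy stand-in for an IR operation (hypothetical, not a Catalyst type)."""
    name: str            # e.g. "Hadamard", "bookmark", "state", "probs"
    wires: tuple = ()
    tag: str = ""

def lower_bookmarks(body, returns):
    """Replace each 'bookmark' op with a 'state' op and add its result
    to the function's return values."""
    new_body, snapshot_results = [], []
    for op in body:
        if op.name == "bookmark":
            state_op = Op("state", tag=op.tag)
            new_body.append(state_op)
            snapshot_results.append(state_op)
        else:
            new_body.append(op)
    # Assumed convention: snapshot results come first, then the original returns.
    return new_body, snapshot_results + list(returns)

body = [
    Op("bookmark", tag="Initial state"),
    Op("Hadamard", (0,)),
    Op("Hadamard", (1,)),
    Op("bookmark", tag="After applying Hadamard gates"),
]
returns = [Op("probs")]

new_body, new_returns = lower_bookmarks(body, returns)
print([op.name for op in new_body])
print([op.name for op in new_returns])
```

In an actual pass the same two steps would apply: rewrite the bookmark operation in place, then update the enclosing function's signature and return so the extra state results are visible to the caller.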
To see what Catalyst mlir a frontend PennyLane python program would produce, you can use the keep_intermediate option in qjit:
import pennylane as qml
from catalyst import qjit

device = qml.device("lightning.qubit", wires=1)

@qjit(keep_intermediate=True)
@qml.qnode(device)
def circuit():
    qml.X(wires=0)
    return qml.state()

print(circuit())
This will produce a directory called circuit (the name is the same as your function name) in your working directory. In there you will find the mlir we produce at various points during compilation. You can inspect them to learn about the catalyst mlir, e.g. how qubit values interact, or how StateOp operations interact with qubit values to produce results; you can even use them for tests!
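A quick way to see what ended up in that directory is to list its contents from Python. This is a small standard-library sketch; `list_intermediate_files` is a hypothetical helper, and it only assumes that the directory is named after the function, as described above. It makes no assumption about which files Catalyst writes.

```python
from pathlib import Path

def list_intermediate_files(function_name, workdir="."):
    """Return the relative paths of all files under <workdir>/<function_name>."""
    out_dir = Path(workdir) / function_name
    if not out_dir.is_dir():
        # Nothing was kept (or the function has not been compiled yet).
        return []
    return sorted(
        p.relative_to(out_dir).as_posix()
        for p in out_dir.rglob("*")
        if p.is_file()
    )

for name in list_intermediate_files("circuit"):
    print(name)
```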
We have a tutorial on how to write new mlir passes in catalyst, which you might wish to consult.
Note that without stage 1, keep_intermediate will of course produce mlir without the mid-circuit bookmarking operations. You can add new operations' definitions in QuantumOps.td, and add them manually to your test mlir. Notice that each operation has a specification for its assembly format that determines what it looks like in textual mlir.
If you have any other questions, don't hesitate to reach out!
Hi @MaxBubblegum47, thank you for tackling this issue. A couple of points that can hopefully clear things up:
- qml.Snapshot is implemented as a "regular" quantum operation in PennyLane, so it will appear on tape in the list of operations.
- Catalyst, when encountering a QNode, will first create a tape from the function (just like PennyLane). If you breakpoint after this block you can see the Snapshot operation on tape.
- From the quantum tape, Catalyst converts all operations into JAX primitives (the JAXPR is the IR used at this level, and JAX primitives are the instructions). Primitives are created with bind calls, and this function is more or less a simple mapping from PennyLane's operations to equivalent JAX primitives.
- Unfortunately, you cannot just place additional qml.state calls onto the tape, because PennyLane measurement processes always live in a separate section of the tape (tape.measurements), and are always assumed to apply at the end. However, the state_p primitive doesn't have this restriction and can appear anywhere in the JAXPR.
- The last question then pertains to how to return the results of the Snapshot. The issue suggests producing additional function results, so collecting all the snapshots and then inserting them into the results of the quantum function will be fine, although if you see a better alternative, that's great too!
- Oh one more thing: You might notice that the Snapshot operation is no longer on the tape by the time you get to converting the tape to JAX primitives. This is because the operation is as of yet "unsupported", and thus removed by the operation decomposition system. Adding Snapshot onto the QJITDevice supported gate list should resolve this issue though.
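The flow described in these points can be modelled in a few lines of plain Python. Everything here is a toy stand-in rather than Catalyst or PennyLane code: `ToyTape`, `keep_supported`, and `to_primitives` are hypothetical names, the decomposition step is reduced to a simple filter, and the tuples only label which primitive would be emitted (with "state_p" borrowed from the name mentioned above).

```python
from dataclasses import dataclass

@dataclass
class ToyTape:
    """Toy tape: operations and measurements live in separate lists."""
    operations: list      # e.g. [("Snapshot", "Initial state"), ("Hadamard", 0)]
    measurements: list    # e.g. ["probs"]; always applied at the end

def keep_supported(tape, supported):
    """Model the decomposition step as a filter that drops unsupported operations."""
    kept = [op for op in tape.operations if op[0] in supported]
    return ToyTape(kept, list(tape.measurements))

def to_primitives(tape):
    """Map each operation to a primitive record; snapshots become extra results."""
    primitives, snapshot_results = [], []
    for name, arg in tape.operations:
        if name == "Snapshot":
            result = f"snapshot_{len(snapshot_results)}"
            primitives.append(("state_p", result))   # emitted mid-circuit
            snapshot_results.append(result)
        else:
            primitives.append(("gate", name, arg))
    measurement_results = []
    for name in tape.measurements:
        primitives.append((name, f"{name}_result"))
        measurement_results.append(f"{name}_result")
    return primitives, snapshot_results + measurement_results

tape = ToyTape(
    operations=[("Snapshot", "Initial state"), ("Hadamard", 0), ("Hadamard", 1),
                ("Snapshot", "After applying Hadamard gates")],
    measurements=["probs"],
)

# Without "Snapshot" in the supported set, the operation is gone before conversion.
without = keep_supported(tape, {"Hadamard"})
# With it added, the snapshots survive and show up as extra function results.
with_snapshot = keep_supported(tape, {"Hadamard", "Snapshot"})
primitives, results = to_primitives(with_snapshot)
print(results)  # ['snapshot_0', 'snapshot_1', 'probs_result']
```

The point of the model is the ordering: the state primitives are emitted at the position of each Snapshot, while the ordinary measurements are still appended at the end.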
Hope this helps and let me know if you have any follow up questions or ideas you want to discuss :)