
[BUG] Incorrect output pytree when using qml.counts() in specific output patterns

Open · mehrdad2m opened this issue 1 year ago · 2 comments

Context

When qml.counts() is used in the output of a quantum circuit compiled with qjit, the output pytree is modified so that the element corresponding to qml.counts() is replaced with tree_structure(("keys", "counts")), i.e. a pair of leaves for the keys and the counts. However, this transformation is buggy: it works for simple cases but transforms more complex output patterns incorrectly.
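For reference, the replacement structure is a tuple of two leaves. The snippet below uses only jax.tree_util with placeholder values (the integers are illustrative assumptions, not values produced by Catalyst) to show that structure and the structure a circuit returning {"1": qml.counts()} is expected to have.

```python
from jax.tree_util import tree_structure

# Structure substituted for each qml.counts() result: two leaves (keys, counts).
# Strings are leaves in JAX pytrees, so this is a 2-tuple of leaves.
counts_tree = tree_structure(("keys", "counts"))
print(counts_tree)             # PyTreeDef((*, *))
print(counts_tree.num_leaves)  # 2

# Expected structure for a circuit returning {"1": qml.counts()}:
# the single counts leaf is expanded in place into a 2-tuple.
# The 0 values are placeholders; only the structure matters here.
expected = tree_structure({"1": (0, 0)})
print(expected)                # PyTreeDef({'1': (*, *)})
```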

An example that works fine:

import pennylane as qml
from catalyst import qjit
from jax.tree_util import tree_flatten

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

The result is as expected:

PyTreeDef({'1': (*, *)})

The following two examples show patterns that result in the wrong output pytree. In the first, qml.counts() is returned inside the first of two dicts:

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

results in:

PyTreeDef(((*, *), {'2': *}))

instead of the expected pytree:

PyTreeDef(({'1': (*, *)}, {'2': *}))

In the second pattern, qml.counts() is returned inside a dict nested in a list:
dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

results in:

PyTreeDef(([{'1': *}, {'2': *}], (*, *)))

while the expected pytree is:

PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))
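Both incorrect outputs are consistent with one hypothesis, which has not been checked against the Catalyst source: the position of the counts measurement among the flattened leaves appears to be used as an index into the top-level children of the output. In the first failing example counts is leaf 0 and top-level child 0 is replaced; in the second it is leaf 1 and top-level child 1 ({"3": ...}) is replaced. The sketch below reproduces both wrong structures with plain pytrees. `buggy_replace` is a hypothetical helper, and the string leaves "c" (counts) and "e" (expval) are placeholders.

```python
from jax.tree_util import tree_leaves, tree_structure

COUNTS = ("keys", "counts")  # placeholder for the 2-leaf counts result


def buggy_replace(out):
    """Hypothetical model of the faulty transform (tuple outputs only):
    the leaf index of the counts placeholder "c" is misused as an index
    into the top-level children of the returned tuple."""
    i = tree_leaves(out).index("c")  # leaf index of the counts measurement
    children = list(out)
    children[i] = COUNTS             # replaces a top-level child, not the leaf
    return tuple(children)


# First failing example: ({"1": counts}, {"2": expval})
out1 = ({"1": "c"}, {"2": "e"})
print(tree_structure(buggy_replace(out1)))  # PyTreeDef(((*, *), {'2': *}))

# Second failing example: ([{"1": expval}, {"2": counts}], {"3": expval})
out2 = ([{"1": "e"}, {"2": "c"}], {"3": "e"})
print(tree_structure(buggy_replace(out2)))  # PyTreeDef(([{'1': *}, {'2': *}], (*, *)))
```

The same model also explains why the first example works: there, counts is leaf 0 and the top-level child 0 happens to be the counts leaf itself.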

A possible solution would be to update trace_quantum_measurements(), which is where the output pytree is modified. For example, one could write a function replace_child_tree(tree, i, subtree) that receives a pytree and replaces the i-th leaf visited in a depth-first traversal of tree with subtree.
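A minimal sketch of such a helper, assuming both tree and subtree are JAX PyTreeDef objects and that i counts leaves in JAX flattening order (depth-first, with dict keys sorted). The name and signature follow the suggestion above and are not an existing Catalyst API. The idea is to rebuild a concrete pytree from placeholder leaves, put a concrete instance of subtree at position i, and re-derive the structure.

```python
from jax.tree_util import tree_structure, tree_unflatten


def replace_child_tree(tree, i, subtree):
    """Return a copy of the PyTreeDef `tree` in which its i-th leaf
    (JAX flattening order, i.e. depth-first) is replaced by the
    PyTreeDef `subtree`. Hypothetical helper, not a Catalyst API."""
    if not 0 <= i < tree.num_leaves:
        raise IndexError(f"leaf index {i} out of range for {tree}")
    # Placeholder leaves (0) for every position of `tree` ...
    leaves = [0] * tree.num_leaves
    # ... except position i, which gets a concrete instance of `subtree`.
    leaves[i] = tree_unflatten(subtree, [0] * subtree.num_leaves)
    # Re-flattening expands the inserted instance into its own nodes.
    return tree_structure(tree_unflatten(tree, leaves))


counts = tree_structure(("keys", "counts"))

# First failing example: counts is leaf 0.
out1 = tree_structure(({"1": 0}, {"2": 0}))
print(replace_child_tree(out1, 0, counts))  # PyTreeDef(({'1': (*, *)}, {'2': *}))

# Second failing example: counts is leaf 1.
out2 = tree_structure(([{"1": 0}, {"2": 0}], {"3": 0}))
print(replace_child_tree(out2, 1, counts))  # PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))
```

If a circuit returns several qml.counts() results, each replacement inserts an extra leaf and shifts the indices of later leaves, so applying the replacements in descending order of leaf index keeps the earlier indices valid.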

mehrdad2m, Aug 13 '24 22:08