pennylane
Fix gradient transform output shape for 2D QNode outputs
Fixes a bug in qml.gradients.param_shift that transposes the QNode output dimensions in the returned tensor:
https://github.com/PennyLaneAI/pennylane/blob/e1b15d27cc5a4a882b552c936580dd903fe3b988/pennylane/gradients/parameter_shift.py#L277
The following example demonstrates the issue:
import pennylane as qml
from pennylane import numpy as np

dev = qml.device('default.qubit', wires=4)

@qml.qnode(dev)
def circuit(params):
    qml.Rot(*params, wires=0)
    return [qml.probs([0,1]), qml.probs([2,3])]

params = np.array([0.5, 0.5, 0.5], requires_grad=True)
qml.gradients.param_shift(circuit)(params).shape
>>> (4, 2, 3)
where the QNode output dimensions are (4, 2) instead of (2, 4).
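The shape above can be reproduced without PennyLane. The following is a minimal NumPy sketch, assuming the per-parameter results are stacked along a leading axis and then fully transposed; the variable names are illustrative and not PennyLane internals.

```python
import numpy as np

# Hypothetical stand-ins: one (2, 4) array per trainable parameter,
# mimicking the per-parameter results a gradient transform collects.
num_params, num_measurements, num_outcomes = 3, 2, 4
per_param = [
    np.full((num_measurements, num_outcomes), float(k)) for k in range(num_params)
]

stacked = np.stack(per_param)  # parameter axis first: (3, 2, 4)
transposed = stacked.T         # a full transpose reverses every axis: (4, 2, 3)

print(stacked.shape)
print(transposed.shape)
```

A full transpose moves the parameter axis to the end as intended, but it also swaps the two QNode output axes, which is where (4, 2) comes from.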
~~Also fixes a bug in qml.gradients.param_shift_hessian [...]~~
Edit: 2nd fix has been moved to #2299
Hey @dime10, what's the status on this PR?
I've been writing some test cases to make sure everything is correct, but I'm still investigating what the correct behaviour should be regarding the transposing that happens. Will update the PR soon!
Actually I think the issue considered here is closely tied to #2296, and the solution attempt does not adequately resolve the issue.
I used the following circuit to compare the gradients computed across diff_methods and interfaces:
# dev is the 4-wire device from the example above; m is the diff_method under test
@qml.qnode(dev, diff_method=m)
def circuit(params):
    qml.Rot(*params, wires=0)
    return [qml.probs([0, 1]), qml.probs([2, 3])]
Current state
https://github.com/PennyLaneAI/pennylane/blob/7b755831a4c042574a69b8eeffba46824a6a624b/pennylane/gradients/finite_difference.py#L385
We obtain the following table for the current state:
+-----------------+-----------+-----------+-----------+-----------+-----------+
| diff_method | transform | auto | jax | tf | torch |
+-----------------+-----------+-----------+-----------+-----------+-----------+
| backprop | None | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (8, 3) |
| parameter-shift | (4, 2, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
| finite-diff | (4, 2, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
+-----------------+-----------+-----------+-----------+-----------+-----------+
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| diff_method | transform | auto | jax | tf | torch |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| backprop | None | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[-0. -0.24 -0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] |
| | | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ] |
| | | | | | [-0. -0. -0. ] |
| | | [[ 0. -0. 0. ] | [[ 0. 0. 0. ] | [[ 0. -0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ]] |
| | | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | |
| parameter-shift | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] |
| | [ 0. 0. 0. ]] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0. 0. ] |
| | [[ 0. 0. 0. ] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] |
| | [ 0. 0. 0. ]] | | | | |
| | | [[ 0. 0. 0. ] | [[ 0. 0. 0. ] | [[ 0. 0. 0. ] | [[ 0. 0.24 0. ] |
| | [[ 0. 0.24 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ]] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] |
| | [[ 0. 0. 0. ] | | | | |
| | [ 0. 0. 0. ]]] | | | | |
| finite-diff | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] |
| | [-0. 0. -0. ]] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [-0. 0. -0. ] |
| | | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0. 0. ] |
| | [[ 0. 0. 0. ] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] |
| | [ 0. 0. 0. ]] | | | | |
| | | [[-0. 0. -0. ] | [[-0. 0. -0. ] | [[-0. 0. -0. ] | [[ 0. 0.24 0. ] |
| | [[ 0. 0.24 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ]] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] |
| | [[ 0. 0. 0. ] | | | | |
| | [ 0. 0. 0. ]]] | | | | |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
Regarding shape, two issues are visible: the swapped output dimensions (4, 2, 3) when using the gradient transforms directly, and the merged output dimensions (8, 3) when using backprop with torch.
Interestingly, however, the tensors produced with torch do match those produced by the transforms directly, which are incorrect if we take backprop as the reference.
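As a side note on the merged (8, 3) shape: the sketch below copies the rounded values from the backprop/torch cell of the table above and checks that splitting the leading axis recovers the (2, 4, 3) layout reported for the other interfaces. It is plain NumPy and makes no claim about how torch produces the merged shape.

```python
import numpy as np

# Rounded values copied from the backprop/torch cell of the table above.
merged = np.array([
    [0.0, -0.24, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.24, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
])  # shape (8, 3)

# Split the leading axis of length 8 back into (measurement, outcome) = (2, 4).
unmerged = merged.reshape(2, 4, 3)

print(unmerged.shape)
print(unmerged[0])
```

The first (4, 3) block of the reshaped array has the same entries as the backprop reference in the auto, jax and tf columns, so in this case the merge only flattens the two output axes and does not reorder entries.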
No transpose
grads = qml.math.stack(grads)
return qml.math.moveaxis(grads, 0, -1)
Removing the transpose that is applied to the output of the gradient transforms, as shown above, produces the following table:
+-----------------+-----------+-----------+-----------+-----------+-----------+
| diff_method | transform | auto | jax | tf | torch |
+-----------------+-----------+-----------+-----------+-----------+-----------+
| backprop | None | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (8, 3) |
| parameter-shift | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
| finite-diff | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
+-----------------+-----------+-----------+-----------+-----------+-----------+
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| diff_method | transform | auto | jax | tf | torch |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| backprop | None | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[-0. -0.24 -0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] |
| | | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ] |
| | | | | | [-0. -0. -0. ] |
| | | [[ 0. -0. 0. ] | [[ 0. 0. 0. ] | [[ 0. -0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ]] |
| | | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | |
| parameter-shift | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0.24 0. ] | [ 0. 0. 0. ] | [ 0. 0.24 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0.24 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0.24 0. ] |
| | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] |
| | | | | | |
| | [[ 0. 0. 0. ] | [[ 0. 0. 0. ] | [[ 0. 0.24 0. ] | [[ 0. 0. 0. ] | [[ 0. 0. 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] |
| finite-diff | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] |
| | [ 0. 0. 0. ] | [ 0. 0.24 0. ] | [-0. 0. -0. ] | [ 0. 0.24 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0.24 0. ] | [-0. 0. -0. ] | [ 0. 0. 0. ] | [-0. 0. -0. ] | [ 0. 0.24 0. ] |
| | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] |
| | | | | | |
| | [[-0. 0. -0. ] | [[ 0. 0. 0. ] | [[ 0. 0.24 0. ] | [[ 0. 0. 0. ] | [[-0. 0. -0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
The shape of the transforms is now correct, and the tensors in the transform and torch columns do match the backprop reference tensor. However, while the shape of the remaining three interfaces was unaffected, their tensors are now incorrect.
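To illustrate why the transform column now agrees with the reference, here is a small NumPy sketch of the stack-then-moveaxis pattern. The random per-parameter arrays are hypothetical placeholders for the real gradients, and `np.stack`/`np.moveaxis` stand in for the `qml.math` calls.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-parameter gradients: 3 parameters, QNode output shape (2, 4).
per_param = [rng.normal(size=(2, 4)) for _ in range(3)]

stacked = np.stack(per_param)      # (3, 2, 4): parameter axis first
jac = np.moveaxis(stacked, 0, -1)  # (2, 4, 3): parameter axis last

# jac[i, j, k] is the derivative of output entry (i, j) w.r.t. parameter k,
# so slicing the last axis returns each per-parameter gradient unchanged.
for k in range(3):
    assert np.array_equal(jac[..., k], per_param[k])

print(jac.shape)
```

Only the parameter axis moves; the two output axes keep their order, so no entry changes its (measurement, outcome) position.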
Reshape
grads = qml.math.stack(grads)
new_shape = qml.math.shape(grads)[1:] + (qml.math.shape(grads)[0],)
return qml.math.reshape(qml.math.T(grads), new_shape)
Lastly, the current solution attempt in the PR, which combines the transpose with a reshape, does fix the shape but does nothing to address the incorrect tensors:
+-----------------+-----------+-----------+-----------+-----------+-----------+
| diff_method | transform | auto | jax | tf | torch |
+-----------------+-----------+-----------+-----------+-----------+-----------+
| backprop | None | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (8, 3) |
| parameter-shift | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
| finite-diff | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
+-----------------+-----------+-----------+-----------+-----------+-----------+
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| diff_method | transform | auto | jax | tf | torch |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| backprop | None | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[-0. -0.24 -0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] |
| | | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ] |
| | | | | | [-0. -0. -0. ] |
| | | [[ 0. -0. 0. ] | [[ 0. 0. 0. ] | [[ 0. -0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ]] |
| | | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | |
| parameter-shift | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] | [[[ 0. -0.24 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] |
| | | | | | |
| | [[ 0. 0.24 0. ] | [[ 0. 0. 0. ] | [[ 0. 0. 0. ] | [[ 0. 0. 0. ] | [[ 0. 0.24 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] |
| finite-diff | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] | [[[-0. -0.24 -0. ] |
| | [-0. 0. -0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [-0. 0. -0. ] |
| | [ 0. 0. 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0.24 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] | [ 0. 0. 0. ]] |
| | | | | | |
| | [[ 0. 0.24 0. ] | [[-0. 0. -0. ] | [[-0. 0. -0. ] | [[-0. 0. -0. ] | [[ 0. 0.24 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] | [ 0. 0. 0. ] |
| | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] | [ 0. 0. 0. ]]] |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
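The difference between the reshape attempt and a pure axis move can be checked in isolation. This is a NumPy sketch with a hypothetical stacked array whose entries are all distinct, so any reordering is visible; `np.reshape`, `.T` and `np.moveaxis` stand in for the corresponding `qml.math` calls.

```python
import numpy as np

# Hypothetical stacked gradients with distinct entries, shape (3, 2, 4).
stacked = np.arange(24).reshape(3, 2, 4)

new_shape = stacked.shape[1:] + (stacked.shape[0],)  # (2, 4, 3)
reshaped = np.reshape(stacked.T, new_shape)          # transpose, then reshape
moved = np.moveaxis(stacked, 0, -1)                  # axis move only

print(reshaped.shape == moved.shape)      # same shape
print(np.array_equal(reshaped, moved))    # but different entry order

# The entry that belongs at [0, 2, 1] ends up at [1, 0, 1] after the reshape.
print(moved[0, 2, 1] == reshaped[1, 0, 1])
```

Both results have shape (2, 4, 3), but the reshape reinterprets the transposed memory layout rather than undoing the axis swap. The relocation from [0, 2, 1] to [1, 0, 1] is consistent with where the 0.24 entry appears in the transform column of the table above.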