Fix gradient transform output shape for 2D QNode outputs

Open dime10 opened this issue 3 years ago • 3 comments

Fixes a bug in qml.gradients.param_shift that transposes the QNode output dimensions in the returned tensor:

https://github.com/PennyLaneAI/pennylane/blob/e1b15d27cc5a4a882b552c936580dd903fe3b988/pennylane/gradients/parameter_shift.py#L277

The following example demonstrates the issue:

import pennylane as qml
from pennylane import numpy as np  # PennyLane's NumPy wrapper supports requires_grad

dev = qml.device('default.qubit', wires=4)

@qml.qnode(dev)
def circuit(params):
    qml.Rot(*params, wires=0)
    return [qml.probs([0,1]), qml.probs([2,3])]

params = np.array([0.5, 0.5, 0.5], requires_grad=True)
qml.gradients.param_shift(circuit)(params).shape
>>> (4, 2, 3)

where the QNode output dimensions are (4, 2) instead of (2, 4).
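
The swap can be reproduced with plain NumPy. The sketch below assumes that the transform stacks one gradient per parameter (each shaped like the QNode output) and then applies a full transpose; the `grads` list is a hypothetical stand-in, not PennyLane's actual intermediate values. A full transpose reverses every axis, whereas moving only the parameter axis to the end keeps the output dimensions in order:

```python
import numpy as np

# Hypothetical stand-in: 3 parameters, each gradient shaped like the
# QNode output (2, 4). The fill value marks which parameter it belongs to.
grads = [np.full((2, 4), float(i)) for i in range(3)]

stacked = np.stack(grads)              # (3, 2, 4): parameter axis first
transposed = stacked.T                 # reverses all axes
moved = np.moveaxis(stacked, 0, -1)    # moves only the parameter axis to the end

print(transposed.shape)  # (4, 2, 3)
print(moved.shape)       # (2, 4, 3)
```

This only illustrates the shape behaviour; whether moving the axis is the right fix across all interfaces is a separate question.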


~~Also fixes a bug in qml.gradients.param_shift_hessian [...]~~

Edit: 2nd fix has been moved to #2299

dime10 avatar Feb 24 '22 21:02 dime10

Hey @dime10, what's the status on this PR?

josh146 avatar Mar 06 '22 16:03 josh146

> Hey @dime10, what's the status on this PR?

I've been writing some test cases to make sure everything is correct, but I'm still investigating what the correct behaviour should be regarding the transposing that happens. Will update the PR soon!

dime10 avatar Mar 07 '22 19:03 dime10

Actually I think the issue considered here is closely tied to #2296, and the solution attempt does not adequately resolve the issue.

I used the following circuit to compare the computed gradients across diff_methods and interfaces (m stands for the diff_method under test):

@qml.qnode(dev, diff_method=m)
def circuit(params):
    qml.Rot(*params, wires=0)
    return [qml.probs([0, 1]), qml.probs([2, 3])]

Current state

https://github.com/PennyLaneAI/pennylane/blob/7b755831a4c042574a69b8eeffba46824a6a624b/pennylane/gradients/finite_difference.py#L385

We obtain the following table for the current state:

+-----------------+-----------+-----------+-----------+-----------+-----------+
| diff_method     | transform | auto      | jax       | tf        | torch     |
+-----------------+-----------+-----------+-----------+-----------+-----------+
| backprop        | None      | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (8, 3)    |
| parameter-shift | (4, 2, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
| finite-diff     | (4, 2, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
+-----------------+-----------+-----------+-----------+-----------+-----------+
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| diff_method     | transform               | auto                    | jax                     | tf                      | torch                   |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| backprop        | None                    | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[-0.   -0.24 -0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |  [ 0.    0.24  0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |  [ 0.    0.    0.  ]    |
|                 |                         |                         |                         |                         |  [-0.   -0.   -0.  ]    |
|                 |                         |  [[ 0.   -0.    0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.   -0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]]   |
|                 |                         |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |                         |
| parameter-shift | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   |
|                 |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |                         |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |
|                 |  [[ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |
|                 |   [ 0.    0.    0.  ]]  |                         |                         |                         |                         |
|                 |                         |  [[ 0.    0.    0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.    0.24  0.  ]   |
|                 |  [[ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |                         |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |
|                 |  [[ 0.    0.    0.  ]   |                         |                         |                         |                         |
|                 |   [ 0.    0.    0.  ]]] |                         |                         |                         |                         |
| finite-diff     | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   |
|                 |   [-0.    0.   -0.  ]]  |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [-0.    0.   -0.  ]   |
|                 |                         |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |
|                 |  [[ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |
|                 |   [ 0.    0.    0.  ]]  |                         |                         |                         |                         |
|                 |                         |  [[-0.    0.   -0.  ]   |  [[-0.    0.   -0.  ]   |  [[-0.    0.   -0.  ]   |  [[ 0.    0.24  0.  ]   |
|                 |  [[ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |                         |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |
|                 |  [[ 0.    0.    0.  ]   |                         |                         |                         |                         |
|                 |   [ 0.    0.    0.  ]]] |                         |                         |                         |                         |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+

Regarding shape, two problems stand out: the swapped output dimensions (4, 2, 3) when the gradient transforms are used directly, and the merged output dimensions (8, 3) when using backprop with torch. Interestingly, the tensors produced with torch match those produced by the transforms directly, which are incorrect if we take backprop as the reference.
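
As a small check on the merged (8, 3) shape, the sketch below copies the rounded values from the torch/backprop cell and from the auto/jax/tf backprop cells of the table above (signed zeros ignored) and shows that, for these values, the merged result is a plain reshape away from the (2, 4, 3) reference. It uses NumPy only and says nothing about how torch arrives at the merged shape:

```python
import numpy as np

# Rounded values from the torch/backprop cell: shape (8, 3).
merged = np.zeros((8, 3))
merged[0, 1] = -0.24
merged[2, 1] = 0.24

# Rounded values from the auto/jax/tf backprop cells: shape (2, 4, 3).
reference = np.zeros((2, 4, 3))
reference[0, 0, 1] = -0.24
reference[0, 2, 1] = 0.24

# Split the merged leading axis back into the two QNode output dimensions.
restored = merged.reshape(2, 4, 3)
print(np.allclose(restored, reference))  # True
```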

No transpose

        grads = qml.math.stack(grads)
        return qml.math.moveaxis(grads, 0, -1)

Removing the transpose that is applied on the output of gradient transforms as shown above produces the following table:

+-----------------+-----------+-----------+-----------+-----------+-----------+
| diff_method     | transform | auto      | jax       | tf        | torch     |
+-----------------+-----------+-----------+-----------+-----------+-----------+
| backprop        | None      | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (8, 3)    |
| parameter-shift | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
| finite-diff     | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
+-----------------+-----------+-----------+-----------+-----------+-----------+
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| diff_method     | transform               | auto                    | jax                     | tf                      | torch                   |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| backprop        | None                    | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[-0.   -0.24 -0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |  [ 0.    0.24  0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |  [ 0.    0.    0.  ]    |
|                 |                         |                         |                         |                         |  [-0.   -0.   -0.  ]    |
|                 |                         |  [[ 0.   -0.    0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.   -0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]]   |
|                 |                         |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |                         |
| parameter-shift | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.24  0.  ]   |
|                 |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |
|                 |                         |                         |                         |                         |                         |
|                 |  [[ 0.    0.    0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.    0.24  0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |
| finite-diff     | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.24  0.  ]   |   [-0.    0.   -0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.24  0.  ]   |   [-0.    0.   -0.  ]   |   [ 0.    0.    0.  ]   |   [-0.    0.   -0.  ]   |   [ 0.    0.24  0.  ]   |
|                 |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |
|                 |                         |                         |                         |                         |                         |
|                 |  [[-0.    0.   -0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.    0.24  0.  ]   |  [[ 0.    0.    0.  ]   |  [[-0.    0.   -0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+

The shape returned by the transforms is now correct, and the tensors in the transform and torch columns match the backprop reference tensor. However, while the shape for the remaining three interfaces (auto, jax, tf) was unaffected, their tensors are now incorrect.

Reshape

        grads = qml.math.stack(grads)
        new_shape = qml.math.shape(grads)[1:] + (qml.math.shape(grads)[0],)
        return qml.math.reshape(qml.math.T(grads), new_shape)

Lastly, the current solution attempt in the PR, which combines the transpose with a reshape, does fix the shape but does nothing to address the incorrect tensors:

+-----------------+-----------+-----------+-----------+-----------+-----------+
| diff_method     | transform | auto      | jax       | tf        | torch     |
+-----------------+-----------+-----------+-----------+-----------+-----------+
| backprop        | None      | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (8, 3)    |
| parameter-shift | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
| finite-diff     | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) | (2, 4, 3) |
+-----------------+-----------+-----------+-----------+-----------+-----------+
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| diff_method     | transform               | auto                    | jax                     | tf                      | torch                   |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
| backprop        | None                    | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[-0.   -0.24 -0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |  [ 0.    0.24  0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |  [ 0.    0.    0.  ]    |
|                 |                         |                         |                         |                         |  [-0.   -0.   -0.  ]    |
|                 |                         |  [[ 0.   -0.    0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.   -0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]    |
|                 |                         |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |  [ 0.    0.    0.  ]]   |
|                 |                         |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |                         |
| parameter-shift | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   | [[[ 0.   -0.24  0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |
|                 |                         |                         |                         |                         |                         |
|                 |  [[ 0.    0.24  0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.    0.    0.  ]   |  [[ 0.    0.24  0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |
| finite-diff     | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   | [[[-0.   -0.24 -0.  ]   |
|                 |   [-0.    0.   -0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [-0.    0.   -0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.24  0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |   [ 0.    0.    0.  ]]  |
|                 |                         |                         |                         |                         |                         |
|                 |  [[ 0.    0.24  0.  ]   |  [[-0.    0.   -0.  ]   |  [[-0.    0.   -0.  ]   |  [[-0.    0.   -0.  ]   |  [[ 0.    0.24  0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |   [ 0.    0.    0.  ]   |
|                 |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |   [ 0.    0.    0.  ]]] |
+-----------------+-------------------------+-------------------------+-------------------------+-------------------------+-------------------------+
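
Why a reshape after the transpose repairs the shape but not the entries can be seen with a small NumPy sketch. The `stacked` array is a hypothetical stand-in with distinct values so that any reordering is visible; the comparison against `moveaxis` only shows that the two operations order elements differently and is not a claim about which ordering each interface expects:

```python
import numpy as np

# Hypothetical stacked gradients with distinct entries:
# axes are (parameter, output dim 0, output dim 1) = (3, 2, 4).
stacked = np.arange(24).reshape(3, 2, 4)

# Transpose plus reshape, mirroring the snippet in the "Reshape" section.
new_shape = stacked.shape[1:] + (stacked.shape[0],)   # (2, 4, 3)
reshaped = np.reshape(stacked.T, new_shape)

# Moving only the parameter axis to the end, for comparison.
moved = np.moveaxis(stacked, 0, -1)

print(reshaped.shape == moved.shape)     # True: both are (2, 4, 3)
print(np.array_equal(reshaped, moved))   # False: the entries are ordered differently
```

The reshape reinterprets the transposed data in memory order rather than permuting axes, so the entries stay in the order the transpose produced.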

dime10 avatar Mar 10 '22 05:03 dime10