pennylane
Gradients of non-linear post-processing functions
Context We have been stumbling over this problem in the classical shadows PR https://github.com/PennyLaneAI/pennylane/pull/2871 as well as in the counts issue https://github.com/PennyLaneAI/pennylane/issues/2932 . I believe it is also related to this older issue https://github.com/PennyLaneAI/pennylane/issues/1722.
Here is an abstract example where a qnode is executed and returns a value $g(x)$, which is then post-processed in $f(g)$.
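@qml.qnode(dev)
def qnode(x):
    ...
    qml.RX(x, wires=0)
    ...
    return g  # schematic: stands for whatever the qnode measures

def f(g):
    return non_linear(g)  # placeholder for any non-linear classical function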
The derivative of $f$ is then given by $\partial_x f(g(x)) = f'(g(x)) \partial_x g(x) = f'(g(x)) (g(x+\pi/2) - g(x-\pi/2))/2$.
This is also what PennyLane is doing. However, there are situations where I want to be able to do something different. For example, in classical shadows, $g$ is a set of abstract bits that inform the classical representation of the snapshot states. From these I can post-process the expectation value, $f$, which is a non-linear function of $g$, that is, of the bits and recipes (measurement choices).
When I call qml.grad(f)(x), PennyLane does the above logic (I am actually not sure what it computes for the gradients of the bits and recipes). But what I would want instead is
$\partial_x f = (f(g(x+\pi/2)) - f(g(x-\pi/2)))/2$
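To make the difference between the two expressions concrete, here is a minimal sketch in plain Python. It does not use PennyLane: `g` is a hypothetical stand-in for a QNode output ($\langle Z \rangle = \cos x$ after an RX rotation), and `f`, `f_prime`, `chain_rule_grad` and `black_box_grad` are illustrative names, not library functions.

```python
import math

def g(x):
    # Toy stand-in for a QNode output: <Z> after RX(x) on |0> equals cos(x).
    return math.cos(x)

def f(val):
    # Non-linear classical post-processing (an arbitrary illustrative choice).
    return val ** 2

def f_prime(val):
    # Derivative of the post-processing function.
    return 2 * val

def chain_rule_grad(x):
    # f'(g(x)) times the two-term parameter-shift derivative of g.
    dg = (g(x + math.pi / 2) - g(x - math.pi / 2)) / 2
    return f_prime(g(x)) * dg

def black_box_grad(x):
    # Shift the argument of the composed function f(g(.)) instead.
    s = math.pi / 2
    return (f(g(x + s)) - f(g(x - s))) / 2

x = 0.3
print("chain rule:", chain_rule_grad(x))
print("black box: ", black_box_grad(x))
```

For this deterministic toy, the chain-rule value agrees with the analytic derivative $-\sin(2x)$, while the black-box value vanishes up to rounding. So the two expressions are genuinely different quantities, and the black-box form is only a derivative when the post-processed quantity itself obeys the two-term shift rule.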
Solutions One solution is to perform the post-processing inside the qnode, i.e. registering a new measurement. As far as I understand, by default, calling grad on this measurement outcome does exactly what I want.
Instead of registering new measurements, it would be great to have the possibility to register custom functions as outputs of qnodes. That custom function would internally call a valid measurement, and post-process it to the desired output, which is then returned by the qnode.
A different solution would be to register custom vjps, but as far as I understand, this is not something very desirable.
Disclaimer While writing this I realized I had a lot of misconceptions, so I am not 100% sure I am making sense. Please understand this issue more as a discussion forum rather than a contained and explicit feature request.
> Instead of registering new measurements, it would be great to have the possibility to register custom functions as outputs of qnodes. That custom function would internally call a valid measurement, and post-process it to the desired output, which is then returned by the qnode.
I think a solution where we make it really easy for users and developers to define their own measurement process, which can be automatically used by devices without needing to modify devices, is the key thing we are missing
I am slightly confused by the chain rule not working here :sweat_smile:
Assuming f, g and x to be scalars, we have (as you wrote): d/dx f(g(x)) = f'(g(x)) * g'(x)
If the derivative of g is given by the parameter-shift rule, a non-linear post-processing does not change that fact, i.e. g'(x) = 0.5[g(x+pi/2)-g(x-pi/2)]. The derivative of f has to be computed independently of the shift rule (as I take it that it's a classical function, it is available via autodiff or so?).
So my question would be whether you are interested in a quantity that is not the derivative d/dx f(g(x)), or whether I am misunderstanding something here :)
The chain rule is working properly here, that is the problem 😄
Essentially, instead of $\partial_x f(g(x)) = f'(g(x)) \partial_x g(x) = f'(g(x)) (g(x+\pi/2) - g(x-\pi/2))/2$, I want $\partial_x f(g(x)) = (f(g(x+\pi/2)) - f(g(x-\pi/2)))/2$.
Or to put it in the context of classical shadows, f -> expval and g -> bits, so I want to compute the derivative
d_x expval(bits(x)) = (expval(bits(x+pi/2)) - expval(bits(x-pi/2)))/2
When we do the post-processing inside the qnode, as far as I understand, g(x) is not tracked, so it is giving me exactly what I want.
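A toy version of this situation, as a sketch only: the functions `bits`, `expval` and `shifted_expval_grad` below are hypothetical stand-ins written in plain Python, not the classical shadows implementation. The sampled bits have no useful derivative of their own, but the estimate built from them can be shifted as a whole.

```python
import math
import random

def bits(x, shots, rng):
    # Toy g(x): Z-basis samples of RX(x)|0>, where P(bit = 1) = sin^2(x/2).
    p1 = math.sin(x / 2) ** 2
    return [1 if rng.random() < p1 else 0 for _ in range(shots)]

def expval(samples):
    # Toy f(g): estimate <Z> from the bits (0 -> +1, 1 -> -1).
    return sum(1 - 2 * b for b in samples) / len(samples)

def shifted_expval_grad(x, shots=100_000, seed=0):
    # Apply the two-term shift to the post-processed estimate, not to the bits.
    rng = random.Random(seed)
    s = math.pi / 2
    plus = expval(bits(x + s, shots, rng))
    minus = expval(bits(x - s, shots, rng))
    return (plus - minus) / 2

estimate = shifted_expval_grad(0.7)
print(estimate, -math.sin(0.7))  # statistical estimate next to the exact derivative of cos(x)
```

Here the exact expectation value is $\cos x$, so the shifted estimate approaches $-\sin x$ as the number of shots grows; the two printed numbers agree only up to shot noise.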
To be honest when I originally wrote this I had a lot of misconceptions about what was going wrong, so this might be more specific to classical shadows than I thought.
A potential other case where you would want something like this is in counting-type loss functions. For example, you may want to disentangle an input state and make your loss function the number of 1s in your measurement statistics (i.e. something like this: https://iopscience.iop.org/article/10.1088/2632-2153/ac0616). Though I am not sure if this is actually a problem since g is linear here.
edit: Instead of non-/linearity, the property I was looking for should rather be whether or not it is a quantum function that gets differentiated via parameter shift.
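One way to see when shifting the whole function is legitimate, as a sketch under an assumption not stated in the thread: suppose the post-processed quantity $F(x) = f(g(x))$ (in expectation) depends on $x$ through a single frequency, $F(x) = a\cos x + b\sin x + c$, as is the case for an expectation value with one Pauli-rotation parameter. Then

$$
\begin{aligned}
F(x+\pi/2) &= -a\sin x + b\cos x + c,\\
F(x-\pi/2) &= a\sin x - b\cos x + c,\\
\frac{F(x+\pi/2) - F(x-\pi/2)}{2} &= -a\sin x + b\cos x = \partial_x F(x).
\end{aligned}
$$

If $F$ contains other frequencies, for example $F(x) = \cos^2 x$, the two-term shift no longer returns the derivative.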
Ah, I see now :) So basically a variant of qml.gradients.param_shift that treats the input function as a black box and just shifts the function arguments would be useful here? :thinking:
@josh146 I think this might be something nice to support, and very easy to code up. For QNodes that have a 1:1 mapping of QNode to tape arguments, this might speed up qml.gradients.param_shift itself, in particular when used with jitted QNodes.
@dwierichs could you explain a bit further? It could just be the early morning fogging my brain, but I'm not sure I follow!
I mean that we could have the following:
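import numpy as np

def shift_it(fn, *params, shifts):
    # shifts[idx] is a list of (coeff, mult, shift) terms for parameter idx;
    # fn is assumed to return a scalar.
    jac = np.zeros(len(params))
    for idx in range(len(params)):
        for coeff, mult, shift in shifts[idx]:
            new_params = list(params)  # params is a tuple, so copy it into a list
            new_params[idx] = new_params[idx] * mult + shift
            jac[idx] += coeff * fn(*new_params)
    return jac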
(which is untested pseudo-code) It doesn't know tapes, it doesn't know QNodes, it just knows callables and parameters, and some formatting of shift rules.
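As an illustration of how such a helper could be called, here is a self-contained sketch in plain Python. The name `black_box_shift`, the test function `toy` and the `two_term` rule are assumptions made for this example; the `(coeff, mult, shift)` layout follows the tuple order used in the pseudo-code above.

```python
import math

def black_box_shift(fn, params, shifts):
    # Hypothetical helper: knows only callables, parameters and shift rules.
    # shifts[idx] is a list of (coeff, mult, shift) terms for parameter idx.
    jac = [0.0] * len(params)
    for idx in range(len(params)):
        for coeff, mult, shift in shifts[idx]:
            shifted = list(params)  # fresh copy for every shifted evaluation
            shifted[idx] = shifted[idx] * mult + shift
            jac[idx] += coeff * fn(*shifted)
    return jac

# Standard two-term rule: 0.5 * [fn(x + pi/2) - fn(x - pi/2)].
two_term = [(0.5, 1.0, math.pi / 2), (-0.5, 1.0, -math.pi / 2)]

def toy(x, y):
    # Single-frequency dependence on each argument, like a simple expectation value.
    return math.cos(x) * math.sin(y)

jac = black_box_shift(toy, [0.4, 1.1], [two_term, two_term])
print(jac)
```

For this `toy` function the result matches the analytic partial derivatives $-\sin x \sin y$ and $\cos x \cos y$, because each argument enters with a single frequency. The caller is responsible for supplying a shift rule that is valid for `fn`.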
That sounds very promising! The heavy lifting would be in the shifts argument here though, right? Or would this be a non-AD version where the user is expected to know the correct shifts?
This is actually how PL used to work! parameter-shift was done at the qfunc level, not the tape level 🙂
The reason we changed was that it ended up being very restricted --- it meant we couldn't include classical processing inside the QNode. This led to both:
- Issues with the user experience, users needed the ability to sometimes have classical processing inside a QNode
- Issues with decompositions; we couldn't use decompositions that led to classical processing.
I suppose in this case, we want something a bit different; rather than treating the qfunc as a black box, we want the black box to look like this:
flowchart LR
A(Differentiable tensor) ---> qnode
subgraph qnode [Blackbox VJP]
direction LR
tape ---> expval ---> B[measurement postprocessing]
end
qnode ---> C(Differentiable tensor)
Whereas we currently have this:
flowchart LR
A(Differentiable tensor) ---> qnode
subgraph qnode [Blackbox VJP]
direction LR
tape ---> expval
end
qnode ---> |Differentiable tensor| B[measurement postprocessing]
Yes, I think that's the right perspective...I'm not sure what the best way of proceeding is here.
The idea of registering callables instead of MeasurementProcesses in QNodes is pretty cool, but I don't think it would be very maintainable...
If we were to use a black box function like in my previous comment, it could be part of the gradients module core, and actually be used to compute the tape-based shift rule (applying it to a simple callable that produces a new tape copy and sets the parameters).
@josh146, as this seems to be "reverting" certain developments, do you think this would be nice to have? Or is there a better path towards supporting this?
> The idea of registering callables instead of MeasurementProcesses in QNodes is pretty cool, but I don't think it would be very maintainable...
@dwierichs I'm curious why you think this? To me, this feels like the most natural solution, especially if we build it into the device API.
For a long time, we have actually wanted to:
- Move the default implementations of expval, var, etc. out of QubitDevice and into the corresponding MeasurementProcess.
- Allow devices to choose whether they use their own expval, var, etc. logic, or utilize the logic provided by a MeasurementProcess.
The advantage here is that any developer/user can define their own measurement process and have it be used by any device, without needing to modify the device. They also get gradient support via calling some_autodiff_framework.grad(...) for free, even if they use NumPy* 🙂 Currently, any time we want to add new measurement support to PL, we have to manually add it to every device we want to support as a new device method, which is a big overhead.
* Assuming the output of their measurement process supports the parameter-shift rule.
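A rough sketch of what this split of responsibilities could look like, in plain Python. Every class and method name here (`ExpvalZ`, `FractionOfOnes`, `ToyDevice`, `process_samples`, `native`) is hypothetical and chosen for illustration; none of it is the PennyLane device or measurement API.

```python
import math
import random

class ExpvalZ:
    """Hypothetical measurement process that carries its own post-processing."""

    def process_samples(self, bits):
        # Map bit 0 -> +1 and bit 1 -> -1, then average.
        return sum(1 - 2 * b for b in bits) / len(bits)

class FractionOfOnes:
    """Hypothetical user-defined measurement; the device needs no changes for it."""

    def process_samples(self, bits):
        return sum(bits) / len(bits)

class ToyDevice:
    """Hypothetical sampler for RX(x)|0> measured in the Z basis."""

    def __init__(self, shots=50_000, seed=0):
        self.shots = shots
        self.rng = random.Random(seed)
        # Optional device-native overrides, keyed by measurement type.
        self.native = {}

    def sample(self, x):
        p1 = math.sin(x / 2) ** 2
        return [1 if self.rng.random() < p1 else 0 for _ in range(self.shots)]

    def execute(self, x, measurement):
        # Use the device's own logic if it registered one; otherwise fall back
        # to the logic shipped with the measurement process.
        native = self.native.get(type(measurement))
        if native is not None:
            return native(x)
        return measurement.process_samples(self.sample(x))

dev = ToyDevice()
dev.native[ExpvalZ] = math.cos  # exact "native" <Z> for this toy circuit
z_native = dev.execute(1.0, ExpvalZ())
ones_fallback = dev.execute(1.0, FractionOfOnes())
print(z_native, ones_fallback)
```

The device only has to produce samples; a new measurement such as `FractionOfOnes` works without touching `ToyDevice`, while a device that has a faster or exact method can still register it.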
To me this sounds like a generally desirable feature to have, and it could substantially clean up some of the device code. I have to admit I still lack the insight into the PennyLane workflow to predict what possible drawbacks or blockers there could be, though.
@josh146 Can you elaborate a bit more on how this would be structured? Would this mean the API is generally changing, or just that e.g. qml.expval in the return statement of the qnode is directly the MeasurementProcess and the qnode knows how to handle it?
It also sounds like a bigger operation; could it be worth implementing @dwierichs's idea in the meantime?