
Gradients of non-linear post-processing functions

Open Qottmann opened this issue 2 years ago • 11 comments

Context

We have stumbled over this problem in the classical shadows PRs (https://github.com/PennyLaneAI/pennylane/pull/2871) as well as in the counts issue (https://github.com/PennyLaneAI/pennylane/issues/2932). I believe it is also related to this older issue: https://github.com/PennyLaneAI/pennylane/issues/1722.

Here is an abstract example where a QNode is executed and returns a value $g(x)$, which is then post-processed by a classical function $f(g)$.

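```python
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def qnode(x):
    # ... further operations ...
    qml.RX(x, wires=0)
    # ... further operations ...
    return qml.expval(qml.PauliZ(0))  # plays the role of g(x)

def f(g):
    return g ** 2  # stands in for some non-linear post-processing
```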

The derivative of $f$ is then given by $\partial_x f(g(x)) = f'(g(x))\,\partial_x g(x) = f'(g(x))\,(g(x+\pi/2) - g(x-\pi/2))/2$.
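To make the chain rule concrete, here is a minimal NumPy sketch under assumed toy choices: $g(x) = \cos x$, which is $\langle Z\rangle$ after `RX(x)` on $|0\rangle$, and the illustrative post-processing $f(g) = g^2$. It multiplies the classical derivative $f'(g(x))$ by the parameter-shift derivative of $g$ and compares the result with the analytic derivative $-\sin(2x)$.

```python
import numpy as np

def g(x):
    # Toy "QNode": <Z> after RX(x) on |0>, which equals cos(x)
    return np.cos(x)

def f(g_val):
    # Illustrative non-linear classical post-processing
    return g_val ** 2

def df(g_val):
    # Classical derivative f'(g), as autodiff would provide it
    return 2 * g_val

def param_shift(fn, x, shift=np.pi / 2):
    # Two-term parameter-shift rule for an RX-type gate
    return (fn(x + shift) - fn(x - shift)) / 2

def chain_rule_grad(x):
    # d/dx f(g(x)) = f'(g(x)) * dg/dx, with dg/dx from the shift rule
    return df(g(x)) * param_shift(g, x)

x = 0.3
print(chain_rule_grad(x), -np.sin(2 * x))
```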

This is also what PennyLane does. However, there are situations where I want to do something different. For example, in classical shadows, $g$ consists of measurement bits that, together with the recipes (the randomly chosen measurement bases), define the classical representation of the snapshot states. From these I post-process the expectation value $f$, which is a non-linear function of the bits and recipes.

When I call qml.grad(f)(x), PennyLane applies the above logic (I am actually not sure what it computes for the gradients of the bits and recipes). What I would want instead is

$\partial_x f = (f(g(x+\pi/2)) - f(g(x-\pi/2)))/2$
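The two quantities can be compared directly in the same assumed toy setting ($g(x) = \cos x$). The sketch below evaluates the shifted-$f$ expression next to the chain-rule derivative. For a linear $f$ they coincide; for the illustrative non-linear choice $f(g) = g^2$ they differ, because $\cos^2 x = (1 + \cos 2x)/2$ contains a frequency that a $\pm\pi/2$ shift does not resolve. Shifting the composite therefore reproduces the derivative only when $x \mapsto f(g(x))$ itself obeys the shift rule.

```python
import numpy as np

def g(x):
    # Toy "QNode": <Z> after RX(x) on |0>
    return np.cos(x)

def shifted_f_grad(f, x, shift=np.pi / 2):
    # Shift the whole composite x -> f(g(x))
    return (f(g(x + shift)) - f(g(x - shift))) / 2

def chain_rule_grad(df, x, shift=np.pi / 2):
    # f'(g(x)) times the parameter-shift derivative of g alone
    return df(g(x)) * (g(x + shift) - g(x - shift)) / 2

x = 0.3

# Linear post-processing: both approaches agree
lin, dlin = (lambda v: 3 * v + 1), (lambda v: 3.0)
print(shifted_f_grad(lin, x), chain_rule_grad(dlin, x))

# Non-linear post-processing: the two quantities differ
sq, dsq = (lambda v: v ** 2), (lambda v: 2 * v)
print(shifted_f_grad(sq, x), chain_rule_grad(dsq, x))
```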

Solutions

One solution is to perform the post-processing inside the QNode, i.e. to register a new measurement. As far as I understand, calling grad on such a measurement outcome does exactly what I want by default.

Instead of registering new measurements, it would be great to have the possibility to register custom functions as outputs of qnodes. That custom function would internally call a valid measurement, and post-process it to the desired output, which is then returned by the qnode.

A different solution would be to register custom VJPs, but as far as I understand, this is not very desirable.

Disclaimer

While writing this I realized I had several misconceptions, so I am not 100% sure I am making sense. Please treat this issue more as a discussion than as a self-contained, explicit feature request.

Qottmann, Aug 12 '22 20:08

> Instead of registering new measurements, it would be great to have the possibility to register custom functions as outputs of qnodes. That custom function would internally call a valid measurement, and post-process it to the desired output, which is then returned by the qnode.

I think the key thing we are missing is a way for users and developers to easily define their own measurement processes, which devices can then use automatically without needing to be modified.

josh146, Aug 12 '22 21:08

I am slightly confused by the chain rule not working here 😅 Assuming $f$, $g$ and $x$ to be scalars, we have (as you wrote) $\frac{d}{dx} f(g(x)) = f'(g(x))\,g'(x)$. If the derivative of $g$ is given by the parameter-shift rule, a non-linear post-processing does not change that fact, i.e. $g'(x) = \frac{1}{2}[g(x+\pi/2) - g(x-\pi/2)]$. The derivative of $f$ has to be computed independently of the shift rule (as I take it that $f$ is a classical function, it is available via autodiff or similar?). So my question is whether you are interested in a quantity that is not the derivative $\frac{d}{dx} f(g(x))$, or whether I am misunderstanding something here :)

dwierichs, Aug 23 '22 08:08

The chain rule is working properly here, that is the problem 😄

Essentially, instead of $\partial_x f(g(x)) = f'(g(x)) \partial_x g(x) = f'(g(x)) (g(x+\pi/2) - g(x-\pi/2))/2$, I want $\partial_x f(g(x)) = (f(g(x+\pi/2)) - f(g(x-\pi/2)))/2$.

Or, to put it in the context of classical shadows ($f \to$ expval and $g \to$ bits), I want to compute the derivative $\partial_x \mathrm{expval}(\mathrm{bits}(x)) = \big(\mathrm{expval}(\mathrm{bits}(x+\pi/2)) - \mathrm{expval}(\mathrm{bits}(x-\pi/2))\big)/2$.
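Why shifting the whole pipeline is the meaningful quantity here can be seen in a minimal single-qubit sketch. The bits are random and discrete, so $f'(\text{bits})$ has no meaning, but the *expected* shadow estimator is again an expectation value and therefore obeys the shift rule. The sketch assumes the state $RX(x)|0\rangle$ with Bloch vector $(0, -\sin x, \cos x)$, uniformly random Pauli measurement bases, and the standard single-snapshot estimator $3\,b\,\delta_{P,Z}$ for $\langle Z\rangle$; it averages exactly over bases and outcomes instead of sampling.

```python
import numpy as np

BASES = ("X", "Y", "Z")

def bloch(x):
    # Bloch vector of RX(x)|0>: (<X>, <Y>, <Z>) = (0, -sin x, cos x)
    return {"X": 0.0, "Y": -np.sin(x), "Z": np.cos(x)}

def shadow_estimate_z(basis, outcome):
    # Single-snapshot shadow estimator for <Z>: 3 * outcome if measured in Z, else 0
    return 3.0 * outcome if basis == "Z" else 0.0

def expected_estimator(x):
    # Exact average over uniformly random bases and Born-rule outcomes (+1/-1)
    r = bloch(x)
    total = 0.0
    for basis in BASES:
        p_plus = (1 + r[basis]) / 2
        for outcome, prob in ((+1, p_plus), (-1, 1 - p_plus)):
            total += prob * shadow_estimate_z(basis, outcome) / len(BASES)
    return total

def shifted_grad(fn, x, shift=np.pi / 2):
    # Shift the whole "bits -> expval" pipeline
    return (fn(x + shift) - fn(x - shift)) / 2

x = 0.3
print(expected_estimator(x), np.cos(x))
print(shifted_grad(expected_estimator, x), -np.sin(x))
```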

When we do the post-processing inside the QNode, as far as I understand, $g(x)$ is not tracked, so it gives me exactly what I want.

To be honest, when I originally wrote this I had several misconceptions about what was going wrong, so this might be more specific to classical shadows than I thought.

Another case where you might want something like this is counting-type loss functions. For example, you might want to disentangle an input state and use the number of 1s in your measurement statistics as the loss function (something like this: https://iopscience.iop.org/article/10.1088/2632-2153/ac0616). I am not sure this is actually a problem, though, since g is linear here.

Edit: rather than non-linearity, the relevant property is whether or not the function is a quantum function that gets differentiated via the parameter-shift rule.

Qottmann, Aug 23 '22 16:08

Ah, I see now :) So basically a variant of qml.gradients.param_shift that treats the input function as a black box and just shifts the function arguments would be useful here? 🤔 @josh146 I think this might be something nice to support, and it would be very easy to code up. For QNodes that have a 1:1 mapping of QNode arguments to tape arguments, this might even speed up qml.gradients.param_shift itself, in particular when used with jitted QNodes.

dwierichs, Aug 26 '22 08:08

@dwierichs could you explain a bit further? It could just be the early morning fogging my brain, but I'm not sure I follow!

josh146, Aug 29 '22 13:08

I mean that we could have the following:

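```python
import numpy as np

def shift_it(fn, *params, shifts):
    # Generic shift rule: only needs a callable returning a scalar and its parameters
    params = list(params)  # tuples have no .copy() and cannot be modified
    jac = np.zeros(len(params))
    for idx in range(len(params)):
        # shifts[idx] is a list of (coeff, mult, shift) terms for parameter idx
        for coeff, mult, shift in shifts[idx]:
            new_params = params.copy()
            new_params[idx] = new_params[idx] * mult + shift
            jac[idx] += coeff * fn(*new_params)
    return jac
```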

(which is untested pseudo-code) It doesn't know tapes, it doesn't know QNodes, it just knows callables and parameters, and some formatting of shift rules.
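A usage sketch of `shift_it` (restated in runnable form so the block stands alone), assuming a toy two-parameter circuit whose $\langle Z\rangle$ is $\cos a \cos b$ (`RX(a)` then `RY(b)` on $|0\rangle$) and the standard two-term rule for Pauli rotations, i.e. the terms `(coeff, mult, shift) = (1/2, 1, π/2)` and `(-1/2, 1, -π/2)` per parameter. The black box passed in already contains a post-processing step, here the probability of measuring $|1\rangle$, $(1 - \langle Z\rangle)/2$.

```python
import numpy as np

def shift_it(fn, *params, shifts):
    params = list(params)
    jac = np.zeros(len(params))
    for idx in range(len(params)):
        for coeff, mult, shift in shifts[idx]:
            new_params = params.copy()
            new_params[idx] = new_params[idx] * mult + shift
            jac[idx] += coeff * fn(*new_params)
    return jac

def expval_z(a, b):
    # Toy circuit: <Z> after RX(a) then RY(b) on |0> equals cos(a) * cos(b)
    return np.cos(a) * np.cos(b)

def prob_one(a, b):
    # Black box including post-processing: probability of measuring |1>
    return (1 - expval_z(a, b)) / 2

rule = [(0.5, 1.0, np.pi / 2), (-0.5, 1.0, -np.pi / 2)]
a, b = 0.2, 0.5
print(shift_it(prob_one, a, b, shifts=[rule, rule]))
```

The rule itself still has to be supplied by the caller; `shift_it` only evaluates it.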

dwierichs, Aug 29 '22 14:08

That sounds very promising! The heavy lifting would be in the shifts argument here though, right? Or would this be a non-AD version where the user is expected to know the correct shifts?
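For the common case of a gate $e^{-i x G}$ whose generator has eigenvalues $\pm r$, the shifts are fixed by $r$ alone: $\partial_x h(x) = r\,[h(x + \tfrac{\pi}{4r}) - h(x - \tfrac{\pi}{4r})]$. The helper below (the name `two_term_shifts` is hypothetical) emits this rule in the `(coeff, mult, shift)` format used by `shift_it`. For Pauli rotations such as `RX`, $r = 1/2$ gives the familiar $\pm\pi/2$ shifts with coefficients $\pm 1/2$. Determining $r$ (or a more general rule) for a given gate remains the part that requires knowledge of the circuit.

```python
import numpy as np

def two_term_shifts(r):
    # Two-term rule for exp(-i x G) with eigenvalues of G equal to +r and -r
    s = np.pi / (4 * r)
    return [(r, 1.0, s), (-r, 1.0, -s)]

def apply_rule(fn, x, rule):
    # Evaluate sum_k coeff_k * fn(mult_k * x + shift_k)
    return sum(c * fn(m * x + s) for c, m, s in rule)

# Pauli rotation (generator X/2, r = 1/2): shifts of +-pi/2, coefficients +-1/2
print(two_term_shifts(0.5))

# A function with frequency 2r = 3 (r = 1.5) is differentiated exactly
r, x = 1.5, 0.4
print(apply_rule(lambda t: np.cos(2 * r * t), x, two_term_shifts(r)), -2 * r * np.sin(2 * r * x))
```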

Qottmann, Aug 30 '22 16:08

This is actually how PL used to work! parameter-shift was done at the qfunc level, not the tape level 🙂

The reason we changed it was that this approach was very restrictive: it meant we couldn't include classical processing inside the QNode. This led to:

  • Issues with the user experience: users sometimes needed to have classical processing inside a QNode.
  • Issues with decompositions: we couldn't use decompositions that led to classical processing.
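A toy illustration of the first restriction, assuming a QNode that applies `RX(x**2)` internally so that its output is $\cos(x^2)$: shifting the QNode argument $x$ by $\pm\pi/2$ (qfunc level) no longer gives the derivative, whereas shifting the gate parameter $\theta = x^2$ (tape level) and applying the chain rule with $\partial_x \theta = 2x$ does.

```python
import numpy as np

def qnode(x):
    # Classical processing inside the QNode: the gate sees theta = x**2
    return np.cos(x ** 2)  # <Z> after RX(x**2) on |0>

def gate_expval(theta):
    # What the tape sees: <Z> as a function of the gate parameter
    return np.cos(theta)

def shift(fn, v, s=np.pi / 2):
    # Two-term parameter-shift rule
    return (fn(v + s) - fn(v - s)) / 2

x = 0.7
exact = -2 * x * np.sin(x ** 2)
qfunc_level = shift(qnode, x)                    # shifts x itself: not the derivative
tape_level = shift(gate_expval, x ** 2) * 2 * x  # shift theta, then chain rule
print(exact, qfunc_level, tape_level)
```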

I suppose in this case, we want something a bit different; rather than treating the qfunc as a black box, we want the black box to look like this:

```mermaid
flowchart LR
    A(Differentiable tensor)  ---> qnode
    subgraph qnode [Blackbox VJP]
        direction LR
        tape ---> expval ---> B[measurement postprocessing]
    end
    qnode ---> C(Differentiable tensor)
```

Whereas we currently have this:

```mermaid
flowchart LR
    A(Differentiable tensor)  ---> qnode
    subgraph qnode [Blackbox VJP]
        direction LR
        tape ---> expval
    end
    qnode ---> |Differentiable tensor| B[measurement postprocessing]
```
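A framework-agnostic sketch of the desired layout (post-processing inside the black box), using a hypothetical `BlackBoxQNode` class that is not PennyLane API. Its VJP shifts the whole `expval -> postprocessing` pipeline instead of multiplying by $f'$. It assumes a single scalar parameter, a scalar output, and that the composite obeys the two-term shift rule; the toy pipeline post-processes $\langle Z\rangle = \cos x$ into the probability of outcome 1.

```python
import numpy as np

class BlackBoxQNode:
    """Hypothetical black box: tape execution plus measurement post-processing."""

    def __init__(self, expval_fn, postprocess):
        self.expval_fn = expval_fn      # parameter -> raw measurement result
        self.postprocess = postprocess  # raw result -> final output

    def forward(self, x):
        return self.postprocess(self.expval_fn(x))

    def vjp(self, x, dy, shift=np.pi / 2):
        # Shift the full pipeline, so the post-processing sits inside the VJP
        grad = (self.forward(x + shift) - self.forward(x - shift)) / 2
        return dy * grad

# Toy pipeline: <Z> = cos(x) post-processed into the probability of outcome 1
box = BlackBoxQNode(np.cos, lambda z: (1 - z) / 2)
x = 0.3
print(box.forward(x), box.vjp(x, dy=1.0), np.sin(x) / 2)
```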

josh146, Aug 30 '22 17:08

Yes, I think that's the right perspective... I'm not sure what the best way to proceed is here. The idea of registering callables instead of MeasurementProcesses in QNodes is pretty cool, but I don't think it would be very maintainable... If we were to use a black-box function like the one in my previous comment, it could be part of the gradients module core and could actually be used to compute the tape-based shift rule (by applying it to a simple callable that creates a copy of the tape and sets its parameters). @josh146, as this seems to be "reverting" certain developments, do you think this would be nice to have? Or is there a better path towards supporting this?
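The reuse described here could look roughly like the following sketch. A hypothetical toy `Tape` holding gate parameters and a toy executor are wrapped into a plain callable that copies the tape, sets new parameters and executes it; the generic function-level rule (the `shift_it` idea, restated compactly) is then applied to that callable. None of these names are PennyLane API.

```python
import copy
import numpy as np

class Tape:
    """Hypothetical toy tape: RX(a) then RY(b) on one qubit, measuring <Z>."""

    def __init__(self, params):
        self.params = list(params)

def execute(tape):
    # Toy executor: <Z> after RX(a) RY(b) on |0> is cos(a) * cos(b)
    a, b = tape.params
    return np.cos(a) * np.cos(b)

def as_callable(tape):
    # Callable that copies the tape, sets the parameters and executes the copy
    def fn(*params):
        new_tape = copy.deepcopy(tape)
        new_tape.params = list(params)
        return execute(new_tape)
    return fn

def shift_it(fn, *params, shifts):
    params = list(params)
    jac = np.zeros(len(params))
    for idx in range(len(params)):
        for coeff, mult, shift in shifts[idx]:
            new_params = params.copy()
            new_params[idx] = new_params[idx] * mult + shift
            jac[idx] += coeff * fn(*new_params)
    return jac

rule = [(0.5, 1.0, np.pi / 2), (-0.5, 1.0, -np.pi / 2)]
tape = Tape([0.2, 0.5])
print(shift_it(as_callable(tape), *tape.params, shifts=[rule, rule]))
```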

dwierichs, Aug 31 '22 10:08

> The idea of registering callables instead of MeasurementProcesses in QNodes is pretty cool, but I don't think it would be very maintainable...

@dwierichs I'm curious why you think this? To me, this feels like the most natural solution, especially if we build it into the device API.

For a long time, we have actually wanted to:

  • Move the default implementations of expval, var, etc. out of QubitDevice and into the corresponding MeasurementProcess

  • Allow devices to choose whether they use their own expval, var, etc. logic, or utilize the logic provided by a MeasurementProcess.
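The two points above can be sketched with a few hypothetical classes (the names `process_samples`, `native` and `measure` are illustrative assumptions; PennyLane's actual classes and methods may differ). Each measurement process carries a default implementation, a device uses its own implementation when it has one and falls back to the measurement's otherwise, and a user-defined measurement works on any device without modifying it.

```python
import numpy as np

class MeasurementProcess:
    """Hypothetical base class carrying a default post-processing of samples."""
    name = "base"

    def process_samples(self, samples):
        raise NotImplementedError

class Expval(MeasurementProcess):
    name = "expval"

    def process_samples(self, samples):
        # Default implementation: mean of +1/-1 eigenvalue samples
        return np.mean(samples)

class Var(MeasurementProcess):
    name = "var"

    def process_samples(self, samples):
        return np.var(samples)

class CountOnes(MeasurementProcess):
    # User-defined measurement: fraction of -1 (i.e. |1>) outcomes
    name = "count_ones"

    def process_samples(self, samples):
        return np.mean(samples == -1)

class Device:
    """Hypothetical device: uses its own logic if available, else the measurement's."""
    native = {}  # measurement name -> device-specific implementation

    def measure(self, mp, samples):
        impl = self.native.get(mp.name)
        return impl(samples) if impl is not None else mp.process_samples(samples)

class FastDevice(Device):
    native = {"expval": lambda s: float(np.sum(s)) / len(s)}

samples = np.array([1, -1, 1, 1])
print(Device().measure(Var(), samples), FastDevice().measure(Expval(), samples))
print(FastDevice().measure(CountOnes(), samples))
```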

The advantage here is that any developer or user can define their own measurement process and have it used by any device, without needing to modify the device. They also get gradient support via calling some_autodiff_framework.grad(...) for free, even if they use NumPy* 🙂 Currently, any time we want to add new measurement support to PL, we have to manually add it as a new method to every device we want to support, which is a big overhead.

* Assuming the output of their measurement process supports the parameter-shift rule.

josh146, Aug 31 '22 13:08

To me this generally sounds like a desirable feature, and it could substantially clean up some of the device code. I have to admit, though, that I still lack enough insight into the PennyLane workflow to predict possible drawbacks or blockers.

@josh146 Can you elaborate a bit more on how this would be structured? Would this mean the API changes in general, or just that, e.g., qml.expval in the QNode's return statement directly is the MeasurementProcess and the QNode knows how to handle it?

It also sounds like a bigger undertaking; could it be worth implementing @dwierichs' idea in the meantime?

Qottmann, Sep 05 '22 16:09