
[BUG] Jitting with JAX produces incorrect results with parameter broadcasting

Open eddddddy opened this issue 3 years ago • 6 comments

Expected behavior

Consider the following QNode:

import pennylane as qml
import jax
import jax.numpy as jnp

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, diff_method="backprop", interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

When the QNode is called with a broadcasted parameter, the following should be output:

>>> x = jnp.array([0.3, 0.6])
>>> jax.jit(circuit)(x)
[0.9553365 0.8253356]

Actual behavior

Instead, the following is output:

>>> x = jnp.array([0.3, 0.6])
>>> jax.jit(circuit)(x)
[0.95533645 0.95533645]

It seems that only the first parameter is used for all calculations and the rest are ignored.

Additional information

If diff_method="parameter-shift" is set in the QNode, the issue disappears.

If jax.jit is removed in the above code, the issue disappears.

The issue also appears if the QNode is decorated with qml.transforms.batch_params.

Source code

No response

Tracebacks

No response

System information

Name: PennyLane
Version: 0.24.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: c:\users\edward.jiang\documents\pennylane
Requires: numpy, scipy, networkx, retworkx, autograd, toml, appdirs, semantic-version, autoray, cachetools, pennylane-lightning
Required-by: PennyLane-Lightning

Platform info:           Windows-10-10.0.19042-SP0
Python version:          3.8.10
Numpy version:           1.22.3
Scipy version:           1.8.0
Installed devices:
- default.gaussian (PennyLane-0.24.0.dev0)
- default.mixed (PennyLane-0.24.0.dev0)
- default.mixed.autograd (PennyLane-0.24.0.dev0)
- default.qubit (PennyLane-0.24.0.dev0)
- default.qubit.autograd (PennyLane-0.24.0.dev0)
- default.qubit.jax (PennyLane-0.24.0.dev0)
- default.qubit.tf (PennyLane-0.24.0.dev0)
- default.qubit.torch (PennyLane-0.24.0.dev0)
- lightning.qubit (PennyLane-Lightning-0.23.0)

Existing GitHub issues

  • [X] I have searched existing GitHub issues to make sure the issue does not already exist.

eddddddy avatar Jun 21 '22 17:06 eddddddy

If the qml.transforms.batch_params decorator is removed, the issue persists. This gives reason to believe that the issue lies with the newly introduced parameter broadcasting.

@eddddddy given that batch_params does not seem to be the cause of this issue, might be worth editing the title + description + main example to reflect this 🙂

josh146 avatar Jun 22 '22 13:06 josh146

Perhaps this is more related to the recent additions to JAX-JIT support than to parameter broadcasting? Not entirely sure 😕.

eddddddy avatar Jun 22 '22 14:06 eddddddy

Hi @eddddddy,

Perhaps this is more related to the recent additions to JAX-JIT support than to parameter broadcasting? Not entirely sure 😕.

As we noted during our chat, the latest changes were added to the JAX-JIT interface, and those changes should not be relevant for diff_method="backprop" (unless the interface is mistakenly being applied).

The idea behind interfaces is that

  1. they convert between NumPy and ML framework (e.g., JAX) objects to execute circuits on any device, and
  2. they provide custom gradient rules.

Interfaces allow us to feed in ML objects into a QNode even when the target device would not natively understand the ML objects.

When we have diff_method="backprop", none of those are required because the entire computation is performed using the ML framework's objects (e.g., jnp.array objects). In such a case with the default.qubit device, under the hood, we would be swapping our device to default.qubit.jax.
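As a toy illustration of this backprop path (a hypothetical standalone function, not PennyLane code): the circuit in the report reduces analytically to ⟨Z₁⟩ = cos(x), and writing it in pure jnp operations lets JAX trace and differentiate straight through it, with no interface-level gradient rule needed.

```python
import jax
import jax.numpy as jnp

# Hedged sketch of what "backprop" means conceptually: the whole
# simulation is expressed in jnp operations, so jax.grad and jax.jit
# work end-to-end. For RX(x) on |0> followed by CNOT, <Z_1> = cos(x).
def circuit_analytic(x):
    return jnp.cos(x)

x = jnp.array([0.3, 0.6])
print(circuit_analytic(x))               # matches the expected output in the report
print(jax.grad(circuit_analytic)(0.3))   # -sin(0.3)
```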

antalszava avatar Jun 22 '22 14:06 antalszava

Referencing this test here so we remember to change it once this is fixed.

eddddddy avatar Jun 22 '22 16:06 eddddddy

Thanks for digging this one out! :+1: The reason seems to be that qml.transforms.broadcast_expand is not JIT-compatible.

edit: Actually, that's not quite true. The problem is that by default, qml.execute, which is called in QNode.__call__, makes use of qml.interfaces.cache_execute, which caches results and therefore does not produce multiple different results, but only a single one that is then retrieved from the cache. I patched PL to print out the cache values and got the following when calling the example above:

>>> out = circuit(x)
cached values: [Traced<ShapedArray(float64[1])>with<DynamicJaxprTrace(level=0/1)>]

As we can see, only one value is computed and stored in the cache. I think this is because all traced tapes after applying broadcast_expand have the same hash, and I'm not sure it is possible to change that.
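The collapse described above can be illustrated with a standalone sketch (the names `Tape`, `execute_one`, and `cached_execute` are hypothetical stand-ins, not PennyLane's actual classes or functions): when a cache is keyed by a hash that is identical for every traced tape, only the first tape is ever executed, and all later tapes receive the first result from the cache.

```python
# Minimal standalone sketch of the caching problem, under the assumption
# that all traced tapes share one hash value.

class Tape:
    def __init__(self, params):
        self.params = params

    @property
    def hash(self):
        # Mimics QuantumTape.hash under a JAX tracer: the parameter's
        # printed form is the same for every traced value, so tapes
        # with different (traced) parameters collide.
        return hash("Traced<ShapedArray(float64[])>")


def execute_one(tape):
    # Pretend execution: the result depends on the parameter.
    return tape.params * 2


def cached_execute(tapes):
    cache = {}
    results = []
    for tape in tapes:
        if tape.hash not in cache:
            cache[tape.hash] = execute_one(tape)
        results.append(cache[tape.hash])
    return results


tapes = [Tape(0.3), Tape(0.6)]
print(cached_execute(tapes))  # both entries come from the first tape's result
```

Only one value ends up in the cache, exactly as in the patched printout above.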

What next?

  1. With the introduction of parameter broadcasting to DefaultQubit, this problem will be gone, but only for DefaultQubit devices. In particular, DefaultMixed and basically all other devices will still have this problem.
  2. This problem did not occur, for example, when the device's batch_transform produced multiple tapes because of a Hamiltonian decomposition. That's because the resulting terms differ in the measured observables, so the tape hashes also differ (if they didn't, the Hamiltonian would not be decomposed, I suppose)
  3. I think we should deactivate caching when executing broadcasted tapes, but I am not sure how to do this best. One option would be to override (but not overwrite!) the execution kwarg cache in QNode.__call__ if the QNode tape after construction has a batch_size that is not None, and if the QNode device does not support broadcasting (there is a device flag for this). @josh146 what do you think?
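Option 3 above could be sketched as follows (all names here are hypothetical illustrations, not PennyLane's actual API, apart from the "supports_broadcasting" capability flag mentioned above): compute an effective cache setting per execution, leaving the user's kwarg itself untouched so it is overridden but not overwritten.

```python
# Hedged sketch: disable caching for a single execution when the tape
# is broadcasted and the device lacks native broadcasting support.
def resolve_cache_kwarg(user_cache, batch_size, device_capabilities):
    """Return the effective cache setting for one execution.

    The user's `cache` kwarg is not mutated, so subsequent
    executions still see the original value.
    """
    supports_broadcasting = device_capabilities.get("supports_broadcasting", False)
    if batch_size is not None and not supports_broadcasting:
        return False  # broadcast-expanded tapes would collide in the cache
    return user_cache


# A broadcasted tape (batch_size=2) on a device without native broadcasting:
print(resolve_cache_kwarg(True, 2, {"supports_broadcasting": False}))  # False
# An unbroadcasted tape keeps the user's setting:
print(resolve_cache_kwarg(True, None, {"supports_broadcasting": False}))  # True
```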

Also tagging @antalszava for visibility. Do you have an idea how to proceed best here? :)

dwierichs avatar Jun 23 '22 07:06 dwierichs

I looked into this once more.

  1. default.qubit no longer produces the bug, because it now makes use of broadcasting natively. Replace it with default.mixed for reproduction/testing purposes. This also reduces the urgency/severity of the bug, I think.
  2. The problem is that qml.tape.QuantumTape.hash is not jit-compatible, because the parameters of the tape operations are converted to strings, making different calls with a tracer indistinguishable.
  3. Taking hash(tape) into account in tape.hash, for example by simply including it in the fingerprint created in QuantumTape.hash, does resolve this bug, but it also prevents the correct recycling of cached results.
  4. It seems that the best hotfix would be to internally deactivate caching in the particular use case where all of the following hold:
     a. dev.capabilities()["supports_broadcasting"] is False,
     b. parameter broadcasting is being used,
     c. diff_method = "backprop",
     d. interface = "jax",
     e. jitting is active, and
     f. cache = True (the default) in the QNode kwargs (so this would be overridden).
  5. However, this hotfix might be unnecessary, because an immutable data structure replacing QuantumTape is on its way, and it is expected not to have this problem.
  6. Side comment: tape.hash is three orders of magnitude slower than the default hash(tape), with about 200 µs vs 150 ns for a small tape. The latter does not work for our purposes, so we cannot simply replace it, but it's good to keep this overhead in mind.
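Point 2 vs. point 6 above can be made concrete with a standalone sketch (the `Op` and `Tape` classes here are hypothetical stand-ins, not PennyLane's): the default id-based hash(tape) is fast but distinguishes content-identical tapes, so it cannot drive cache lookups, while a content-based tape.hash can, at the cost of string-serializing every parameter.

```python
# Hedged sketch of the two hashing strategies discussed above.
class Op:
    def __init__(self, name, params):
        self.name = name
        self.params = params


class Tape:
    def __init__(self, ops):
        self.ops = ops

    @property
    def hash(self):
        # Content-based fingerprint, mimicking QuantumTape.hash:
        # parameters are serialized to strings (the slow part, and
        # the part that makes JAX tracers indistinguishable).
        return hash(tuple((op.name, str(op.params)) for op in self.ops))


t1 = Tape([Op("RX", [0.3])])
t2 = Tape([Op("RX", [0.3])])

print(t1.hash == t2.hash)    # True: same content, so a cache hit is possible
print(hash(t1) == hash(t2))  # False: the default hash is id-based
```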

dwierichs avatar Aug 03 '22 11:08 dwierichs

I can no longer reproduce this bug on master. @eddddddy, @josh146 , can you confirm?

dwierichs avatar Feb 09 '23 12:02 dwierichs

@dwierichs if the minimal non-working example in the issue now works in master, I would say safe to close

josh146 avatar Feb 09 '23 16:02 josh146

Made the mistake of closing this because it is fixed for "default.qubit", although I had already mentioned that this would happen :facepalm: :

With the introduction of parameter broadcasting to DefaultQubit, this problem will be gone, but only for DefaultQubit devices. In particular, DefaultMixed and basically all other devices will still have this problem.

I reopen the issue and modify the reported example to use "default.mixed" instead.

dwierichs avatar Feb 21 '23 09:02 dwierichs