pennylane
[BUG] Jitting with JAX produces incorrect results with parameter broadcasting
Expected behavior
Consider the following QNode:
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, diff_method="backprop", interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))
When the QNode is called with a broadcasted parameter, the output should be:
>>> x = jnp.array([0.3, 0.6])
>>> jax.jit(circuit)(x)
[0.9553365 0.8253356]
Actual behavior
Instead, the following is output:
>>> x = jnp.array([0.3, 0.6])
>>> jax.jit(circuit)(x)
[0.95533645 0.95533645]
It seems that only the first parameter is used for all calculations and the rest are ignored.
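For reference, the expected values can be verified analytically without PennyLane: applying RX(x) to wire 0 followed by a CNOT gives the state cos(x/2)|00⟩ - i sin(x/2)|11⟩, so the expectation value of PauliZ on wire 1 is cos²(x/2) - sin²(x/2) = cos(x). Broadcasting should therefore apply cos elementwise (a plain-Python check, not PennyLane code):

```python
import math

# Analytic expectation for this circuit: <Z_1> = cos(x).
# Broadcasting over x should apply it elementwise.
xs = [0.3, 0.6]
expected = [math.cos(x) for x in xs]
print(expected)  # ≈ [0.9553365, 0.8253356]
```

This matches the expected output above and confirms the second entry of the actual output is wrong.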
Additional information
If diff_method="parameter-shift" is set in the QNode, the issue disappears.
If jax.jit is removed in the above code, the issue disappears.
The issue also appears if the QNode is decorated with qml.transforms.batch_params.
Source code
No response
Tracebacks
No response
System information
Name: PennyLane
Version: 0.24.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: c:\users\edward.jiang\documents\pennylane
Requires: numpy, scipy, networkx, retworkx, autograd, toml, appdirs, semantic-version, autoray, cachetools, pennylane-lightning
Required-by: PennyLane-Lightning
Platform info: Windows-10-10.0.19042-SP0
Python version: 3.8.10
Numpy version: 1.22.3
Scipy version: 1.8.0
Installed devices:
- default.gaussian (PennyLane-0.24.0.dev0)
- default.mixed (PennyLane-0.24.0.dev0)
- default.mixed.autograd (PennyLane-0.24.0.dev0)
- default.qubit (PennyLane-0.24.0.dev0)
- default.qubit.autograd (PennyLane-0.24.0.dev0)
- default.qubit.jax (PennyLane-0.24.0.dev0)
- default.qubit.tf (PennyLane-0.24.0.dev0)
- default.qubit.torch (PennyLane-0.24.0.dev0)
- lightning.qubit (PennyLane-Lightning-0.23.0)
Existing GitHub issues
- [X] I have searched existing GitHub issues to make sure the issue does not already exist.
If the qml.transforms.batch_params decorator is removed, the issue stays. This gives reason to believe that the issue lies with the newly introduced parameter broadcasting.
@eddddddy given that batch_params does not seem to be the cause of this issue, might be worth editing the title + description + main example to reflect this 🙂
Perhaps this is more related to the recent additions to JAX-JIT support than to parameter broadcasting? Not entirely sure 😕.
Hi @eddddddy,
Perhaps this is more related to the recent additions to JAX-JIT support than to parameter broadcasting? Not entirely sure 😕.
As we noted during our chat, the latest changes were added to the JAX-JIT interface, and those changes should not be relevant for diff_method="backprop" (unless the interface is mistakenly being applied).
The idea behind interfaces is that
- they convert between NumPy and ML framework (i.e., JAX) objects to execute circuits on any device and
- they provide custom gradient rules.
Interfaces allow us to feed in ML objects into a QNode even when the target device would not natively understand the ML objects.
When we have diff_method="backprop", none of those are required because the entire computation is performed using the ML framework's objects (e.g., jnp.array objects). In such a case with the default.qubit device, under the hood, we would be swapping our device to default.qubit.jax.
Referencing this test here so we remember to change it once this is fixed.
Thanks for digging this one out! 👍 The reason seems to be that qml.transforms.broadcast_expand is not JIT compatible.
edit: Actually, that's not quite true. The problem is that by default, qml.execute (called in QNode.__call__) makes use of qml.interfaces.cache_execute, which caches the result and therefore does not produce multiple different results but only a single one, which is then retrieved from the cache. I patched PL to print out the cache values and, when calling the example above, receive:
>>> out = circuit(x)
cached values: [Traced<ShapedArray(float64[1])>with<DynamicJaxprTrace(level=0/1)>]
As we can see, only one value is computed and stored in the cache. I think this is because all traced tapes after applying broadcast_expand have the same hash, and I'm not sure it's possible to change that.
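The failure mode can be sketched in plain Python (no PennyLane; cached_execute and the dict-based "tapes" are hypothetical stand-ins for cache_execute and traced tapes): if every expanded tape produces the same cache key, only the first result is actually computed and all later tapes receive it from the cache.

```python
# Minimal sketch of a hash-keyed execution cache: when all broadcast-expanded
# tapes collide on the same key, only the first tape is ever executed.
def cached_execute(tapes, execute_one, cache):
    results = []
    for tape in tapes:
        key = tape["hash"]  # hypothetical hash field
        if key not in cache:
            cache[key] = execute_one(tape)
        results.append(cache[key])
    return results

# Two "tapes" that differ in their parameter but share a hash,
# mimicking traced tapes after broadcast expansion.
tapes = [{"hash": "same", "param": 0.3}, {"hash": "same", "param": 0.6}]
out = cached_execute(tapes, lambda t: t["param"], cache={})
print(out)  # [0.3, 0.3] -- the second tape's result is silently wrong
```

This mirrors the single cached value printed above: one entry serves both broadcast components.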
What next?
- With the introduction of parameter broadcasting to DefaultQubit, this problem will be gone, but only for DefaultQubit devices. In particular, DefaultMixed and basically all other devices will still have this problem.
- This problem, for example, did not occur when the device's batch_transform produces multiple tapes because of a Hamiltonian decomposition. That's because the resulting terms differ in the measured observables and the tape hashes also differ (if they didn't, the Hamiltonian would not be decomposed, I suppose).
- I think we should deactivate caching when executing broadcasted tapes, but I am not sure how to do this best. One option would be to override (but not overwrite!) the execution kwarg cache in QNode.__call__ if the QNode tape after construction has a batch_size that is not None, and if the QNode device does not support broadcasting (there is a device flag for this). @josh146 what do you think?
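The "override but not overwrite" option could look roughly like this plain-Python sketch (resolve_cache is a hypothetical helper, not PennyLane API; the real logic would live in QNode.__call__ and consult the device's capability flag):

```python
# Hypothetical sketch: disable caching for this execution only, when the
# constructed tape is broadcasted and the device cannot broadcast natively.
# The user's cache kwarg itself is left untouched (override, not overwrite).
def resolve_cache(user_cache, batch_size, supports_broadcasting):
    if batch_size is not None and not supports_broadcasting:
        return False  # caching would merge distinct broadcast components
    return user_cache

print(resolve_cache(True, batch_size=2, supports_broadcasting=False))     # False
print(resolve_cache(True, batch_size=None, supports_broadcasting=False))  # True
print(resolve_cache(True, batch_size=2, supports_broadcasting=True))      # True
```

Only the combination "broadcasted tape on a non-broadcasting device" loses caching; all other executions keep the user's setting.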
Also tagging @antalszava for visibility. Do you have an idea how to proceed best here? :)
I looked into this once more.
- default.qubit no longer produces the bug, because it now makes use of broadcasting. Replace it by default.mixed for reproducing/testing purposes. This also reduces the urgency/severity of the bug, I think.
- The problem is that qml.tape.QuantumTape.hash is not jit-compatible, because the parameters of the tape operations are converted to strings, making different calls with a tracer indistinguishable.
- Taking hash(tape) into account in tape.hash, for example by simply including it in the fingerprint created in QuantumTape.hash, does resolve this bug, but it also prevents the correct recycling of cached results.
- It seems that the best hotfix would be to internally deactivate caching in the particular use case of
  a. dev.capabilities()["supports_broadcasting"] = False, and
  b. using parameter broadcasting, and
  c. diff_method="backprop", and
  d. interface="jax", and
  e. jitting, and
  f. cache=True (the default) in the QNode kwargs (so this would be overridden).
- However, this hotfix might be unnecessary, because an immutable data structure replacing QuantumTape is on its way, and it is expected not to have this problem.
- Side comment: tape.hash is 3 orders of magnitude slower than the default hash(tape), with 200 µs vs 150 ns for a small tape. The latter does not work for our purposes, so we cannot simply replace it, but it's good to keep this overhead in mind.
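Why a string-based fingerprint is not jit-compatible can be demonstrated without PennyLane or JAX: a tracer's repr shows only the abstract shape/dtype, never the concrete value, so tapes built from different inputs become indistinguishable. FakeTracer and tape_hash below are hypothetical stand-ins for a JAX tracer and QuantumTape.hash, not real internals.

```python
# A mock tracer: like a real JAX tracer, its repr exposes only the
# abstract value, not the concrete data it carries.
class FakeTracer:
    def __init__(self, value):
        self.value = value  # concrete value, invisible to repr

    def __repr__(self):
        return "Traced<ShapedArray(float64[])>"

# A string-based fingerprint in the style described above.
def tape_hash(params):
    return hash(str([repr(p) for p in params]))

h1 = tape_hash([FakeTracer(0.3)])
h2 = tape_hash([FakeTracer(0.6)])
print(h1 == h2)  # True -- different parameters, identical hash
```

Under jit, every broadcast component is such a tracer, which is exactly why the cache returns one result for all of them.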
I can no longer reproduce this bug on master. @eddddddy, @josh146 , can you confirm?
@dwierichs if the minimal non-working example in the issue now works in master, I would say safe to close
Made the mistake of closing this, because it is fixed for "default.qubit", although I had already mentioned that this would happen 🤦:
With the introduction of parameter broadcasting to DefaultQubit, this problem will be gone, but only for DefaultQubit devices. In particular, DefaultMixed and basically all other devices will still have this problem.
I am reopening the issue and modifying the reported example to use "default.mixed" instead.