JAX support for control flow in jitted quantum functions
Feature details
In #1676, @ankit27kh noted that jitted quantum functions containing control flow (for example, jax.lax.cond) raise an error.
This issue is present even on the forward pass.
```python
import jax
import jax.numpy as jnp
import pennylane as qml

jax.config.update("jax_enable_x64", True)

dev = qml.device("default.qubit", wires=1, shots=None)

@jax.jit
@qml.qnode(dev, interface="jax")
def energy(a):
    # Each branch constructs *and returns* a PennyLane operation
    jax.lax.cond(a > 0.3, lambda _: qml.PauliX(wires=0), lambda _: qml.PauliZ(wires=0), None)
    return qml.expval(qml.PauliZ(0))

energy(jnp.array(0.3))
```
TypeError: Value PauliX(wires=[0]) with type <class 'pennylane.ops.qubit.non_parametric_ops.PauliX'> is not a valid JAX type
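The issue is specific to JAX: jax.lax.cond requires both branch functions to return valid JAX types (arrays, or pytrees of arrays with matching structure), and PennyLane operations are not JAX types. Support for such cases would therefore have to be added.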
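The constraint can be seen in plain JAX, without PennyLane. In the minimal sketch below, branches that return arrays are accepted, while branches that return an arbitrary Python object (the hypothetical NotAJaxType, standing in for a gate object) are expected to fail with the same kind of TypeError as the snippet above.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Both branches return JAX arrays of the same shape/dtype, so lax.cond accepts them
    return jax.lax.cond(x > 0.3, jnp.sin, jnp.cos, x)

class NotAJaxType:
    """Hypothetical stand-in for a gate object that JAX does not know how to trace."""

@jax.jit
def g(x):
    # Both branches return a plain Python object, which is not a valid JAX type
    return jax.lax.cond(x > 0.3, lambda _: NotAJaxType(), lambda _: NotAJaxType(), x)

def raises_type_error(fn, x):
    """Return True if calling fn(x) raises a TypeError."""
    try:
        fn(x)
    except TypeError:
        return True
    return False

print(f(jnp.array(1.0)))
print(raises_type_error(g, jnp.array(1.0)))
```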
Implementation
We would likely need to create a new function for classical conditional statements that wraps jax.lax.cond.
How important would you say this feature is?
1: Not important. Would be nice to have.
Additional information
No response
A while ago, I ran into the same issue. I tried to solve it by registering the Operation class as a PyTree, but couldn't get it working 🤔
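One plausible obstacle to the PyTree approach, sketched with a hypothetical ToyGate class (not PennyLane's Operation): even when a gate class is registered as a pytree, jax.lax.cond requires both branches to return the same pytree structure. If the gate type is kept as static auxiliary data, a PauliX branch and a PauliZ branch produce different structures and are rejected; only branches that differ in the array-valued parameters of the same gate type pass.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class ToyGate:
    """Hypothetical gate: a static name (aux data) plus array parameters (children)."""

    def __init__(self, name, params):
        self.name = name
        self.params = params

    def tree_flatten(self):
        # Parameters are traced leaves; the gate name is part of the tree structure
        return (self.params,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, children[0])

@jax.jit
def same_structure(a):
    # Same gate type in both branches: structures match, only the leaf differs
    return jax.lax.cond(a > 0.3, lambda: ToyGate("RX", jnp.pi), lambda: ToyGate("RX", 0.0))

@jax.jit
def different_structure(a):
    # Different gate types: the aux data differs, so the pytree structures differ
    return jax.lax.cond(
        a > 0.3,
        lambda: ToyGate("PauliX", jnp.zeros(0)),
        lambda: ToyGate("PauliZ", jnp.zeros(0)),
    )

gate = same_structure(jnp.array(1.0))

try:
    different_structure(jnp.array(1.0))
    mismatch_error = None
except TypeError as err:
    mismatch_error = err  # lax.cond rejects branches with mismatched output structures
```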
The error in this example is due to the use of lambda functions that return the quantum gates. Normally we only construct quantum gate objects, without ever using their return value explicitly. If we instead use functions that return None, the error doesn't appear:
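```python
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=1, shots=None)

# Branch functions that only construct gates and implicitly return None
def if_then():
    qml.PauliX(wires=0)

def if_else():
    qml.PauliZ(wires=0)

@jax.jit
@qml.qnode(dev, interface="jax")
def energy(a):
    jax.lax.cond(a > 0.3, if_then, if_else)
    return qml.expval(qml.PauliZ(0))

energy(jnp.array(1))
```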
The main issue, however, is that JAX traces both branches of the conditional, so the resulting tape contains the gates from both branches.
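The double tracing can be reproduced in plain JAX. In the sketch below, a Python list named recorded (a hypothetical stand-in for the QNode's tape) collects a side effect from each branch. Under jax.jit the predicate is a tracer, so lax.cond traces both branches and both entries appear, even though only one branch can be selected at runtime.

```python
import jax
import jax.numpy as jnp

recorded = []  # hypothetical stand-in for a tape that records operations during tracing

def if_then():
    recorded.append("PauliX")  # Python side effect: runs when the branch is traced

def if_else():
    recorded.append("PauliZ")

@jax.jit
def f(a):
    jax.lax.cond(a > 0.3, if_then, if_else)  # both branches return None
    return a

f(jnp.array(1.0))
print(recorded)
```

Calling f again with an input of the same shape and dtype hits the jit cache, so nothing further is appended: the list is populated by tracing, not by execution.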
Nice catch @dime10! I wonder how we can get around the double-tracing issue. My guess is that it is highly non-trivial and that we would need a way to represent quantum instructions with conditionals.
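As a toy illustration of what a conditional quantum instruction might look like, one option is to make each branch return data rather than construct an operation. Below, each branch returns a gate's matrix as a JAX array, and a hypothetical hand-written one-qubit simulator applies whichever matrix lax.cond selects at runtime. This is a sketch of the idea only, not how PennyLane's tapes or devices work.

```python
import jax
import jax.numpy as jnp

# Gate matrices as JAX arrays (same shape and dtype, so lax.cond accepts either branch)
PAULI_X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
PAULI_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)

@jax.jit
def energy(a):
    state = jnp.array([1, 0], dtype=jnp.complex64)  # start in |0>
    # Each branch returns data (a matrix), not an operation object
    gate = jax.lax.cond(a > 0.3, lambda: PAULI_X, lambda: PAULI_Z)
    state = gate @ state
    # Expectation value <psi| Z |psi>; vdot conjugates its first argument
    return jnp.real(jnp.vdot(state, PAULI_Z @ state))

print(energy(jnp.array(1.0)))  # PauliX branch: state |1>, so <Z> = -1
print(energy(jnp.array(0.3)))  # 0.3 > 0.3 is False: PauliZ branch, so <Z> = 1
```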