pennylane
[WIP] Create JVP capabilities and replace Jax custom VJP by custom JVP
Context:
The JAX interface cannot be used in forward mode because it relies on custom VJPs (`jax.custom_vjp`), which do not come with JVP support. A custom JVP (`jax.custom_jvp`), on the other hand, lets JAX derive the VJP automatically.
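To illustrate the limitation described above, here is a minimal pure-JAX sketch that does not use PennyLane. The functions `f_vjp` and `f_jvp` are hypothetical stand-ins, not code from this PR. Both wrap `sin(x)`: one with a custom VJP and one with a custom JVP.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in wrapped with a custom VJP (reverse mode only).
@jax.custom_vjp
def f_vjp(x):
    return jnp.sin(x)


def f_vjp_fwd(x):
    # Forward pass: return the output and save x as a residual.
    return jnp.sin(x), x


def f_vjp_bwd(x, g):
    # Backward pass: cotangent g times d/dx sin(x).
    return (jnp.cos(x) * g,)


f_vjp.defvjp(f_vjp_fwd, f_vjp_bwd)


# The same function wrapped with a custom JVP instead.
@jax.custom_jvp
def f_jvp(x):
    return jnp.sin(x)


@f_jvp.defjvp
def f_jvp_rule(primals, tangents):
    (x,), (t,) = primals, tangents
    # The tangent output must be linear in t so JAX can transpose it.
    return jnp.sin(x), jnp.cos(x) * t


x = jnp.array(0.3)

# Reverse mode works with the custom VJP ...
rev_from_vjp = jax.grad(f_vjp)(x)

# ... but forward mode is expected to raise an error.
try:
    jax.jacfwd(f_vjp)(x)
    forward_mode_failed = False
except Exception:
    forward_mode_failed = True

# With the custom JVP, both modes work: JAX transposes the JVP rule.
fwd_from_jvp = jax.jacfwd(f_jvp)(x)
rev_from_jvp = jax.grad(f_jvp)(x)

print(forward_mode_failed, fwd_from_jvp, rev_from_jvp)
```

Reverse mode is available on the custom-JVP function because the rule is linear in the tangent, so JAX can transpose it to obtain the VJP.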
Description of the Change:
This PR provides the tools to use JVPs in PennyLane and replaces the JAX-Python custom_vjps with custom_jvps.
Benefits:
The JAX interface supports both forward and reverse propagation, which unlocks the `jax.jacfwd` and `jax.hessian` functions.
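As a rough sketch of why a JVP-based rule unlocks these functions, the snippet below uses a hypothetical function `expval` (not PennyLane code) that returns `cos(x)`, the expectation value of PauliZ after `RX(x)` on `|0>`. Its custom JVP applies the two-term parameter-shift rule. This is an illustration under those assumptions, not the implementation in this PR.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a QNode: <Z> after RX(x) on |0> equals cos(x).
@jax.custom_jvp
def expval(x):
    return jnp.cos(x)


@expval.defjvp
def expval_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    shift = jnp.pi / 2
    # Two-term parameter-shift rule; calling expval again keeps the
    # rule itself differentiable, so higher-order derivatives work.
    shifted_grad = (expval(x + shift) - expval(x - shift)) / 2
    return expval(x), shifted_grad * x_dot


x = jnp.array(0.3)

grad_rev = jax.grad(expval)(x)     # reverse mode, analytically -sin(x)
grad_fwd = jax.jacfwd(expval)(x)   # forward mode, analytically -sin(x)
hess = jax.hessian(expval)(x)      # second derivative, analytically -cos(x)

print(grad_rev, grad_fwd, hess)
```

Because the rule evaluates `expval` at shifted arguments instead of differentiating through it, nested differentiation simply applies the shift rule again, which is what makes `jax.hessian` usable here.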
Possible Drawbacks:
Related GitHub Issues:
Examples
import jax
import pennylane as qml

# Assumed device; the original snippet does not show how `dev` is defined.
dev = qml.device("default.qubit", wires=1)

@qml.qnode(device=dev, interface="jax", diff_method="parameter-shift")
def circuit_param_shift(x, y):
    qml.RX(x, wires=0)
    qml.RY(y, wires=0)
    return qml.expval(qml.PauliZ(0))

a = jax.numpy.array(0.3)
b = jax.numpy.array(0.9)

print("Param shift")
print(jax.grad(circuit_param_shift)(a, b))
print(jax.jacfwd(circuit_param_shift)(a, b))
print(jax.jacrev(circuit_param_shift)(a, b))
print(jax.jacobian(circuit_param_shift)(a, b))

# Renamed from circuit_adjoint: this QNode uses finite differences, not the adjoint method.
@qml.qnode(device=dev, interface="jax", diff_method="finite-diff")
def circuit_finite_diff(x, y):
    qml.RX(x, wires=0)
    qml.RY(y, wires=0)
    return qml.expval(qml.PauliZ(0))

print("Finite-diff")
print(jax.grad(circuit_finite_diff)(a, b))
print(jax.jacfwd(circuit_finite_diff)(a, b))
print(jax.jacrev(circuit_finite_diff)(a, b))
print(jax.jacobian(circuit_finite_diff)(a, b))

@qml.qnode(device=dev, interface="jax", diff_method="backprop")
def circuit_backprop(x, y):
    qml.RX(x, wires=0)
    qml.RY(y, wires=0)
    return qml.expval(qml.PauliZ(0))

print("Backprop")
print(jax.grad(circuit_backprop)(a, b))
print(jax.jacfwd(circuit_backprop)(a, b))
print(jax.jacrev(circuit_backprop)(a, b))
print(jax.jacobian(circuit_backprop)(a, b))

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:
- A one-to-two sentence description of the change. You may include a small working example for new features.
- A link back to this PR.
- Your name (or GitHub username) in the contributors section.
https://github.com/PennyLaneAI/pennylane/pull/3170