Fix compilation of Jacobian with for loops and vmap
Context: Some gradients/Jacobians do not compile because the insertion point in the lowering is not set correctly.
Description of the Change: Set the insertion point at the call op rather than inside the body of the function.
Benefits: We can compile and run more derivatives of vmap and for loops (a sketch follows).
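For illustration, a minimal sketch of a derivative taken through a compiled for loop (an assumed example, not taken from this PR; the device, wire count, and names are placeholders):

```python
import pennylane as qml
import catalyst

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(dev)
def circuit(x):
    # Apply RY(x) on each wire inside a compiled for loop.
    @catalyst.for_loop(0, 2, 1)
    def apply_ry(i):
        qml.RY(x, wires=i)

    apply_ry()
    return qml.expval(qml.PauliZ(0))

# Differentiate through the for loop in the compiled program.
grad_fn = catalyst.qjit(catalyst.grad(circuit))
print(grad_fn(0.3))
```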
Drawback: Gradients acting on vmap compile but return wrong results; see the xfailed test.
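A minimal sketch of what such an xfailed test could look like (illustrative only; the actual test in this PR may differ):

```python
import pytest
import pennylane as qml
import catalyst
from jax import numpy as jnp

@pytest.mark.xfail(reason="grad of vmap compiles but returns wrong results")
def test_grad_of_vmap():
    dev = qml.device("lightning.qubit", wires=1)

    @qml.qnode(dev)
    def circuit(x):
        qml.RY(x, wires=0)
        return qml.expval(qml.PauliZ(0))

    @catalyst.qjit
    def grad_of_batched_cost(xs):
        def cost(xs):
            # Sum the batched expectation values; since <Z> = cos(x),
            # the gradient w.r.t. each element should be -sin(x_i).
            return jnp.sum(catalyst.vmap(circuit)(xs))

        return catalyst.grad(cost)(xs)

    xs = jnp.array([0.1, 0.2, 0.3])
    assert jnp.allclose(grad_of_batched_cost(xs), -jnp.sin(xs))
```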
[sc-59758]
Hello. You may have forgotten to update the changelog!
Please edit doc/changelog.md on your branch with:
- A one-to-two sentence description of the change. You may include a small working example for new features.
- A link back to this PR.
- Your name (or GitHub username) in the contributors section.
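For example, an entry along these lines could work (suggested wording; #XXXX is a placeholder for this PR's number):

```md
* Fix the insertion point used when lowering gradients of programs
  containing `vmap` and for loops, so that more derivatives compile and run.
  [(#XXXX)](https://github.com/PennyLaneAI/catalyst/pull/XXXX)
```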
Testing this fix on the following example results in a very long (possibly infinite) compile or run time:
```python
import pennylane as qml
import jax
from jax import numpy as jnp
import catalyst

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev_name = "lightning.qubit"
dev = qml.device(dev_name, wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""
    for i in range(n_wires):
        qml.RY(data[i], wires=i)

    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

# try broadcasting
jit_circuit = catalyst.qjit(catalyst.vmap(circuit, in_axes=(1, None)))

def my_model(data, weights, bias):
    # works with default.qubit
    if dev_name == "default.qubit":
        return circuit(data, weights) + bias

    # works with lightning.qubit, not broadcasted
    # return jnp.array([circuit(jnp.array(d), weights) for d in data.T])

    # only works with loss_fn, fails at grad step
    return jit_circuit(data, weights) + bias

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss

weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.0)
params = {"weights": weights, "bias": bias}

print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))  # runs for > 20 minutes
```
[sc-49763]
Closes https://github.com/PennyLaneAI/catalyst/issues/294