
Fix compilation of Jacobian with for loops and vmap

rmoyard opened this issue 2 years ago • 4 comments

Context: Some gradients and Jacobians do not compile because the insertion point is not set correctly during lowering.

Description of the Change: The insertion point is now the call op rather than the body of the function.

Benefits: More derivatives involving vmap and for loops can now be compiled and run.

Drawbacks: Gradients acting on vmap now compile but return wrong results; see the xfailed test.
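To have a reference for what such programs should return, here is a minimal sketch in plain JAX (not Catalyst) of a Jacobian taken through a vmap whose per-sample body contains a loop. The function `body`, the weights, and the inputs are illustrative assumptions, not code from this PR. Comparing a compiled Catalyst result against a JAX reference like this is one way to catch the wrong-results case mentioned above.

```python
import jax
import jax.numpy as jnp


def body(x, w):
    """Hypothetical per-sample function whose value is built up in a loop over the weights."""

    def step(i, acc):
        # Add sin(w[i] * x) to the running total on each iteration.
        return acc + jnp.sin(w[i] * x)

    # Static trip count (w.shape[0]), so fori_loop supports reverse-mode autodiff.
    return jax.lax.fori_loop(0, w.shape[0], step, jnp.zeros_like(x))


# Map over the leading axis of x; the weights are shared by every sample.
batched = jax.vmap(body, in_axes=(0, None))

# Jacobian of the batched outputs with respect to the shared weights.
jac_fn = jax.jacobian(batched, argnums=1)

xs = jnp.array([0.1, 0.5, 1.0])
w = jnp.array([1.0, 2.0, 3.0, 4.0])
jac = jac_fn(xs, w)  # one row per sample, one column per weight: shape (3, 4)

# Analytic reference: d/dw_i of sum_j sin(w_j * x) is x * cos(w_i * x).
expected = xs[:, None] * jnp.cos(w[None, :] * xs[:, None])
print(jnp.allclose(jac, expected, atol=1e-5))
```

Because the analytic derivative is known here, the check does not depend on any particular compiler; the same inputs could be fed to a `qjit`-compiled equivalent and compared against `expected`.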

[sc-59758]

rmoyard, Oct 25 '23

Hello. You may have forgotten to update the changelog! Please edit doc/changelog.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

github-actions[bot], Mar 19 '24

Testing this fix on the following example results in a very long (possibly infinite) compile time or runtime:

import pennylane as qml
import jax
from jax import numpy as jnp
import catalyst

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev_name = "lightning.qubit"
dev = qml.device(dev_name, wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    for i in range(n_wires):
        qml.RY(data[i], wires=i)

    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

# try broadcasting
jit_circuit = catalyst.qjit(catalyst.vmap(circuit, in_axes=(1, None)))

def my_model(data, weights, bias):
    # works with default.qubit
    if dev_name == "default.qubit":
        return circuit(data, weights) + bias

    # works with lightning.qubit, not broadcasted
    # return jnp.array([circuit(jnp.array(d), weights) for d in data.T])

    # only works with loss_fn, fails at grad step
    return jit_circuit(data, weights) + bias

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss


weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))  # runs for > 20 minutes
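One way to narrow down where the hang comes from is to keep the reproducer's data layout and its `jax.jit`/`jax.grad` pipeline but swap the circuit for a classical stand-in, so the same loss can be differentiated in plain JAX without Catalyst. In the sketch below, `stand_in_circuit` is a hypothetical function, not a PennyLane circuit, chosen only to match the input and output shapes of `circuit`. It also makes the batching explicit: as with `jax.vmap`, `in_axes=(1, None)` maps over the 4 columns of the (5, 4) `data` array, one column per target, while `weights` is shared.

```python
import jax
import jax.numpy as jnp

n_wires = 5
# Same data layout as the reproducer: 20 values reshaped to (n_wires, 4).
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])


def stand_in_circuit(x, weights):
    """Hypothetical classical stand-in for `circuit`: a length-n_wires vector in, a scalar out."""
    angles = x + weights[:, 0] + weights[:, 1] + weights[:, 2]
    return jnp.sum(jnp.cos(angles))


# in_axes=(1, None): map over the 4 columns of data and share weights across them.
batched = jax.vmap(stand_in_circuit, in_axes=(1, None))


@jax.jit
def loss_fn(params, data, targets):
    predictions = batched(data, params["weights"]) + params["bias"]
    # Kept as in the reproducer: len(data) is n_wires (5), not the number of samples (4).
    return jnp.sum((targets - predictions) ** 2 / len(data))


params = {"weights": jnp.ones([n_wires, 3]), "bias": jnp.array(0.0)}
print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))
```

If this plain-JAX version returns immediately while the Catalyst version does not, that would suggest the problem lies in the vmap/gradient lowering rather than in the loss or the data layout.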

dime10, May 07 '24

[sc-49763]

rmoyard, Jul 04 '24

Closes https://github.com/PennyLaneAI/catalyst/issues/294

rmoyard, Jul 23 '24