
Propagate the gradient to part of an array only

Open minhtriet opened this issue 8 months ago • 2 comments

Feature details

I want to optimize the coordinates of different molecules in a reaction A+B->AB. It would make sense to keep the coordinates of A fixed and optimize only B's. However, this doesn't work right now.
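
To make the request concrete, here is a minimal sketch of the intended behaviour in plain JAX, without PennyLane. `toy_energy` and `COORDS_A` are hypothetical stand-ins invented for illustration (the real loss would come from `qml.qchem.molecular_hamiltonian` and a QNode); the point is only that the gradient has the shape of B's coordinates while A stays fixed.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed coordinates of molecule A (a single atom at the origin).
COORDS_A = jnp.zeros(3)


def toy_energy(flat_coords):
    # Hypothetical stand-in for the molecular energy:
    # squared distance between the two atoms.
    xyz = flat_coords.reshape(2, 3)
    return jnp.sum((xyz[0] - xyz[1]) ** 2)


def loss_f(coords_b):
    # Only coords_b is a function argument, so only it receives a gradient.
    full = jnp.concatenate([COORDS_A, coords_b])
    return toy_energy(full)


coords_b = jnp.array([1.0, 1.0, 1.0])
grad_b = jax.grad(loss_f)(coords_b)
print(grad_b.shape)  # (3,)

# A few plain gradient-descent steps on B's coordinates only.
step_size = 0.1
for _ in range(10):
    coords_b = coords_b - step_size * jax.grad(loss_f)(coords_b)
print(coords_b, loss_f(coords_b))
```

Closing over the fixed part and concatenating inside the loss keeps the optimizer state the same size as the trainable part, which is the pattern the example in this issue tries to use.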

Suppose I have this JAX traced array:

>>> coords
Traced<ConcreteArray([1. 1. 1.], dtype=float32)>with<JVPTrace(level=2/0)>
...

Then jnp.array([0, 0, 0, *coords]) does not work within a PennyLane context. The error is jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[2,3].

This happens because \qchem\openfermion_obs.py contains the line geometry_dhf = qml.numpy.array(coordinates.reshape(len(symbols), 3)). At the end of the stack trace, this call tries to convert the JAX traced array to a NumPy array.
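
The same class of error can be triggered without PennyLane, which may help isolate the cause. The sketch below is an assumption-laden stand-in: it uses plain `numpy.asarray` in place of `qml.numpy.array`, and the function name `loss_with_numpy_conversion` is hypothetical. It only shows that asking NumPy to materialize an array that is still being traced by `jax.grad` raises `TracerArrayConversionError`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_with_numpy_conversion(coords_b):
    # The concatenated array is still a tracer while jax.grad is running.
    full = jnp.concatenate([jnp.zeros(3), coords_b])
    # Stand-in for the conversion in qchem: NumPy calls __array__() on the tracer.
    geometry = np.asarray(full.reshape(2, 3))
    return geometry.sum()


caught = None
try:
    jax.grad(loss_with_numpy_conversion)(jnp.ones(3))
except jax.errors.TracerArrayConversionError as err:
    caught = err

print(type(caught).__name__)
```

If this reading is right, the concatenation itself is not the problem; the failure comes from the later conversion of the traced result to a NumPy array.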

Implementation

Here is a minimal example that reproduces the error. The issue is at line 17, the call to qml.qchem.molecular_hamiltonian.

import pennylane as qml
from pennylane import numpy as np
import jax
import optax

dev = qml.device("default.qubit", 4)

@qml.qnode(dev)
def circuit_expected(H):
    qml.BasisState([1, 1, 0, 0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
    return qml.expval(H)


def loss_f(coord):
    symbols = ["H", "H"]
    H, qb = qml.qchem.molecular_hamiltonian(symbols, jax.numpy.array([0, 0, 0, *coord]))  # TracerArrayConversionError is raised here
    return circuit_expected(H)

H_1 = jax.numpy.array([1., 1., 1.])
opt = optax.sgd(learning_rate=0.4)
opt_coords_state = opt.init(H_1)

for i in range(10):
    grad_coordinates = jax.grad(loss_f, 0)(H_1)
    updates, opt_coords_state = opt.update(grad_coordinates, opt_coords_state)
    H_1 = optax.apply_updates(H_1, updates)
    print(grad_coordinates)


How important would you say this feature is?

2: Somewhat important. Needed this quarter.

Additional information

No response

minhtriet, May 30 '24 07:05