pennylane
Propagate the gradient to part of an array only
Feature details
I want to optimize the coordinates of different molecules in a reaction A + B -> AB. It would make sense to fix the coordinates of A and optimize only those of B. However, this doesn't work right now.
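As a point of reference, the intended gradient flow can be sketched in pure JAX. This is a minimal sketch under assumptions: `coords_a`, `toy_loss` and the target distance 1.4 are hypothetical, and the harmonic distance term only stands in for the real molecular energy. A is a constant that is concatenated with B, so `jax.grad` returns a gradient only for B's three coordinates.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed coordinates of molecule A (a single atom at the origin)
coords_a = jnp.array([0.0, 0.0, 0.0])

def toy_loss(coords_b):
    # Build the full geometry: A is a constant, B is the variable being differentiated
    geometry = jnp.concatenate([coords_a, coords_b])
    # Hypothetical stand-in for the energy: squared deviation of the A-B distance from 1.4
    r = jnp.linalg.norm(geometry[3:] - geometry[:3])
    return (r - 1.4) ** 2

grad_b = jax.grad(toy_loss)(jnp.array([1.0, 1.0, 1.0]))
print(grad_b.shape)  # (3,)
```

Because A never enters the function as an argument, no gradient is computed for it at all; this is the behavior the request asks for once PennyLane's qchem functions are in the loop.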
Suppose I have this JAX traced array:

```
>>> coords
Traced<ConcreteArray([1. 1. 1.], dtype=float32)>with<JVPTrace(level=2/0)>
```
Then `jnp.array([0, 0, 0, *coords])` won't work (within a PennyLane context). The error is:

`jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[2,3].`
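This happens because PennyLane's `qchem` module contains the line `geometry_dhf = qml.numpy.array(coordinates.reshape(len(symbols), 3))`. At the end of the stack trace, the traced JAX array is therefore converted into a NumPy array, which JAX does not allow for traced values.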
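The failure mode can be reproduced without PennyLane. In this minimal sketch, `numpy_reshape_loss` is a hypothetical function that mimics the `qml.numpy.array(coordinates.reshape(...))` pattern by handing a traced array to plain NumPy, while `jax_reshape_loss` (also hypothetical) keeps the same reshape inside `jax.numpy`. It illustrates where the error comes from; it is not PennyLane's actual code.

```python
import numpy as onp
import jax
import jax.numpy as jnp

def numpy_reshape_loss(coords):
    # Reshape works on the tracer, but plain NumPy cannot hold a JAX tracer,
    # so this conversion raises TracerArrayConversionError (shape float32[2,3])
    geometry = onp.asarray(coords.reshape(2, 3))
    return geometry.sum()

def jax_reshape_loss(coords):
    # Same computation kept inside jax.numpy, so the tracer is never converted
    geometry = jnp.reshape(coords, (2, 3))
    return geometry.sum()

x = jnp.arange(6.0)
try:
    jax.grad(numpy_reshape_loss)(x)
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True

print(conversion_failed)  # True
print(jax.grad(jax_reshape_loss)(x))  # [1. 1. 1. 1. 1. 1.]
```

The contrast suggests that differentiating through the geometry requires the coordinates to stay in the JAX array type until the Hamiltonian is built.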
Here is my minimal example to reproduce the error. The error is raised by the `qml.qchem.molecular_hamiltonian(...)` call inside `loss_f`.
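```python
import pennylane as qml
from pennylane import numpy as np
import jax
import optax

dev = qml.device("default.qubit", wires=4)

@qml.qnode(dev, interface="jax")
def circuit_expected(H):
    # Prepare the Hartree-Fock state of H2 and apply a fixed double excitation
    qml.BasisState([1, 1, 0, 0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
    return qml.expval(H)

def loss_f(coord):
    symbols = ["H", "H"]
    # The first H atom is fixed at the origin; only the second atom's coordinates are traced.
    # This call raises jax.errors.TracerArrayConversionError.
    H, qb = qml.qchem.molecular_hamiltonian(symbols, jax.numpy.array([0, 0, 0, *coord]))
    return circuit_expected(H)

# Initial coordinates of the second H atom (the only ones being optimized)
H_1 = jax.numpy.array([1., 1., 1.])
opt = optax.sgd(learning_rate=0.4)
opt_coords_state = opt.init(H_1)

for i in range(10):
    grad_coordinates = jax.grad(loss_f, 0)(H_1)
    updates, opt_coords_state = opt.update(grad_coordinates, opt_coords_state)
    H_1 = optax.apply_updates(H_1, updates)
```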
How important would you say this feature is?
2: Somewhat important. Needed this quarter.
Additional information
No response