pennylane
Propagate the gradient to part of an array only
Feature details
I want to optimize the coordinates of the molecules in a reaction A + B -> AB. It would make sense to keep the coordinates of A fixed and optimize only those of B. However, this doesn't work right now.
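In pure JAX, "fix A, optimize B" can be expressed by passing only B's coordinates to jax.grad and stacking A's fixed coordinates onto them inside the loss. The sketch below is only an illustration of that gradient pattern. It replaces the PennyLane expectation value with a hypothetical harmonic toy_energy and an assumed bond length R0, so it never calls molecular_hamiltonian.

```python
import jax
import jax.numpy as jnp

# Assumed equilibrium distance for the hypothetical toy energy (illustration only).
R0 = 1.4

def toy_energy(geometry):
    # Hypothetical stand-in for the molecular energy: a harmonic "bond"
    # between atom 0 and atom 1. geometry has shape (n_atoms, 3).
    bond = jnp.linalg.norm(geometry[1] - geometry[0])
    return (bond - R0) ** 2

# A's coordinates are a constant captured by the loss, not a jax.grad argument.
coords_a = jnp.array([0.0, 0.0, 0.0])

def loss(coords_b):
    # Combine the fixed and trainable parts inside the traced function.
    geometry = jnp.stack([coords_a, coords_b])
    return toy_energy(geometry)

coords_b = jnp.array([1.0, 1.0, 1.0])
grad_b = jax.grad(loss)(coords_b)  # gradient with respect to B's coordinates only
print(grad_b.shape)  # (3,)
```

Because coords_a never enters jax.grad as an argument, no gradient is computed or applied for it.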
Suppose I have this JAX traced array:
```
>>> coords
Traced<ConcreteArray([1. 1. 1.], dtype=float32)>with<JVPTrace(level=2/0)>
...
```
Then jnp.array([0, 0, 0, *coords]) fails inside a PennyLane computation with:
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[2,3].
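For comparison, the same construction works under jax.grad when no PennyLane code is involved, which suggests the failure happens inside the library rather than in this line. A minimal pure-JAX check, using a hypothetical loss padded_sum_of_squares:

```python
import jax
import jax.numpy as jnp

def padded_sum_of_squares(coords):
    # Unpacking a traced array and mixing it with Python constants is fine in
    # pure JAX: jnp.array builds a new traced array from the pieces.
    full = jnp.array([0, 0, 0, *coords])
    # jnp.concatenate expresses the same padding without unpacking.
    full_alt = jnp.concatenate([jnp.zeros(3), coords])
    return jnp.sum(full ** 2) + jnp.sum(full_alt ** 2)

coords = jnp.array([1.0, 1.0, 1.0])
print(jax.grad(padded_sum_of_squares)(coords))  # [4. 4. 4.]
```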
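The error comes from \qchem\openfermion_obs.py, which contains the line geometry_dhf = qml.numpy.array(coordinates.reshape(len(symbols), 3)) (hence the shape float32[2,3]: two atoms with three coordinates each). At the end of the stack trace, this line converts the traced JAX array into a NumPy array, which JAX does not allow inside a transformation such as jax.grad.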
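The root cause can be reproduced without PennyLane. In the sketch below, plain numpy.array stands in for qml.numpy.array (an assumption made for illustration): constructing a NumPy array from a traced value calls __array__ and raises the same error, while staying in jax.numpy keeps the value traced and lets gradients flow.

```python
import jax
import jax.numpy as jnp
import numpy as np

def reshape_with_numpy(coords):
    # Mirrors the pattern in openfermion_obs.py: a NumPy constructor on a traced
    # value invokes __array__, which JAX forbids for tracers.
    return np.array(coords.reshape(2, 3)).sum()

def reshape_with_jax(coords):
    # Staying in jax.numpy keeps the value traced, so jax.grad works.
    return jnp.asarray(coords.reshape(2, 3)).sum()

coords = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])

try:
    jax.grad(reshape_with_numpy)(coords)
except jax.errors.TracerArrayConversionError as err:
    print("NumPy conversion failed:", type(err).__name__)

print(jax.grad(reshape_with_jax)(coords))  # [1. 1. 1. 1. 1. 1.]
```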
Implementation
Here is a minimal example that reproduces the error. The error is raised at the qml.qchem.molecular_hamiltonian call inside loss_f.
```python
import pennylane as qml
from pennylane import numpy as np
import jax
import optax

dev = qml.device("default.qubit", 4)

@qml.qnode(dev)
def circuit_expected(H):
    # Hartree-Fock state for 2 electrons in 4 spin orbitals, then a fixed double excitation
    qml.BasisState([1, 1, 0, 0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
    return qml.expval(H)

def loss_f(coord):
    symbols = ["H", "H"]
    # First atom fixed at the origin; only the second atom's coordinates are trainable.
    # Under jax.grad this call raises jax.errors.TracerArrayConversionError.
    H, qb = qml.qchem.molecular_hamiltonian(symbols, jax.numpy.array([0, 0, 0, *coord]))
    return circuit_expected(H)

H_1 = jax.numpy.array([1., 1., 1.])
opt = optax.sgd(learning_rate=0.4)
opt_coords_state = opt.init(H_1)

for i in range(10):
    grad_coordinates = jax.grad(loss_f, 0)(H_1)
    updates, opt_coords_state = opt.update(grad_coordinates, opt_coords_state)
    H_1 = optax.apply_updates(H_1, updates)
    print(grad_coordinates)
```
How important would you say this feature is?
2: Somewhat important. Needed this quarter.