pennylane
Make qchem jax compatible
Context:
The deprecation of Autograd means we are moving to JAX for automatic differentiation. Since JAX arrays do not support the requires_grad keyword, we need a different solution.
Description of the Change:
We keep backwards compatibility with pnp by checking which interface the user is using, inferred from the interface of the input tensors. If the interface is autograd, we stick to the old workflow and check requires_grad using getattr().
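The dispatch described above can be sketched as a small helper. Note that `infer_interface` is a hypothetical illustration of the logic, not the actual function added in this PR:

```python
import jax.numpy as jnp
import numpy as np

def infer_interface(tensor):
    """Hypothetical sketch of the interface check described above."""
    # JAX arrays never carry a requires_grad attribute; for them,
    # trainability is decided later by jax.grad(..., argnums=...)
    if isinstance(tensor, jnp.ndarray):
        return "jax"
    # autograd/pnp tensors expose requires_grad; read it via getattr()
    if getattr(tensor, "requires_grad", None) is not None:
        return "autograd"
    return "numpy"
```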
If the user passes a jax array for any of [coordinates, coefficients, alpha], we assume the user wants to use JAX and define all undefined coeffs/alphas as jax arrays. If a user decides to mix pnp with jax, we don't hard-cast the rest into either interface, since we can't make that decision for them; mixing the two therefore results in a warning.
WHEN USING JAX:
If users wish to differentiate any of these parameters, they should mark the ones they want differentiable using the JAX UI, e.g. jax.grad(..., argnums=<index/indices of differentiable parameter(s)>)(*args). In our case, due to technical limitations, *args must always be exactly [coordinates, coefficients, alpha]; no other order is allowed, and none of them can be omitted. This also applies when you are NOT using jax.grad or another jax transform: in diff_hamiltonian(...)(*args), the args (if using JAX) must likewise be exactly [coordinates, coefficients, alpha]. When you do differentiate, jax.grad(..., argnums=1)(coordinates, coefficients, alpha) means you want coefficients to be differentiable. Note this is a departure from the UI of qml.grad, where you could do qml.grad(..., argnum=0)(coefficients) instead.
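A minimal sketch of this argnums convention, with a hypothetical toy function standing in for a real qchem cost function (toy_energy and its arrays are illustrative only; the fixed argument order is the point):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a qchem cost function that takes exactly
# (coordinates, coefficients, alpha), in that fixed order.
def toy_energy(coordinates, coefficients, alpha):
    return jnp.sum(coordinates**2) + jnp.sum(coefficients * alpha)

coordinates = jnp.array([0.0, 1.0, 2.0])
coefficients = jnp.array([0.5, 0.5])
alpha = jnp.array([1.0, 2.0])

# argnums=1 marks `coefficients` as the differentiable argument;
# all three args must still be passed, in the fixed order.
grad_fn = jax.grad(toy_energy, argnums=1)
g = grad_fn(coordinates, coefficients, alpha)  # d/d(coefficients) = alpha
```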
Additional notes:
The UI for qml.grad and everything else is unchanged for autograd and pnp users. However, if you are using JAX and the args keyword in molecular_hamiltonian and related hamiltonians, you will need to define all of [coordinates, coefficients, alpha] as well, since args is passed downstream to diff_hamiltonian(...)(*args).
Benefits: qchem is now JAX compatible.
Possible Drawbacks: More changes may be needed for JIT support, and there may be performance issues. The UI differs between qml.grad and jax.grad, and the expectations for the args keyword differ between jax arrays and pnp arrays.
Related GitHub Issues: [sc-69776] [sc-69778]
Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:
- A one-to-two sentence description of the change. You may include a small working example for new features.
- A link back to this PR.
- Your name (or GitHub username) in the contributors section.
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 99.64%. Comparing base (c5fd5bc) to head (705395a). Report is 285 commits behind head on master.
Additional details and impacted files
@@ Coverage Diff @@
## master #6096 +/- ##
==========================================
- Coverage 99.64% 99.64% -0.01%
==========================================
Files 469 468 -1
Lines 44331 44065 -266
==========================================
- Hits 44173 43907 -266
Misses 158 158
I think we can add the warning/error for 64-bit precision via jax in this PR itself.
Let's make this a separate PR. We'll have to add test cases as well and I don't want this PR to be that big.
Hey @austingmhuang, just wondering here how come you are adding argnum to the various qchem functionalities?
I would expect that we simply need to change the logic in the qchem functions to use qml.math.requires_grad(array), which should support all interfaces:
That is, the user only passes argnum to jax.grad or jax.jacobian, and internally code uses qml.math.requires_grad
Hey @austingmhuang, just wondering here how come you are adding argnum to the various qchem functionalities?
For functions like overlap_matrix, we need to know what's differentiable in order to know what to do. So functionally, doing qml.qchem.overlap_matrix() will need that information somehow.
I would expect that we simply need to change the logic in the qchem functions to use qml.math.requires_grad(array), which should support all interfaces:
Uhm, I'm a bit confused here. In this situation, everything would be non-differentiable and we would just have a function that sets the differentiability based on whatever is set in argnums from some jax function. For something like qml.qchem.diff_hamiltonian(), how would I specify that I want to differentiate specifically the geometry, for example? I'm probably missing something (I'm not that well-versed with jax), but it's not clear to me how I could do that with the example you have there 😅
jax.jacobian(qml.qchem.diff_hamiltonian(mol, cutoff, core, active), argnums=???)(?)
For functions like overlap_matrix, we need to know what's differentiable in order to know what to do. So functionally, doing qml.qchem.overlap_matrix() will need that information somehow.
Yep, and we should be able to do this via qml.math.requires_grad, so that it is automatically inferred from the parameters :)
For something like qml.qchem.diff_hamiltonian(), how would I specify that I want to differentiate specifically the geometry, for example?
The typical approach is that, if a user is calling a workflow via jax.grad(..., argnums=...)(*args), then it's the jax.grad function (via argnums) that determines which parameters are differentiable! So the user would not be specifying argnums on internal functions within the workflow, this should be extracted automatically via qml.math.requires_grad.
The typical approach is that, if a user is calling a workflow via jax.grad(..., argnums=...)(*args), then it's the jax.grad function (via argnums) that determines which parameters are differentiable! So the user would not be specifying argnums on internal functions within the workflow; this should be extracted automatically via qml.math.requires_grad.
So in your example, you have a dummy function f that has two variables x and y, and you can select which one you want to differentiate by specifying an index, 0 for x, 1 for y. If I wanted to differentiate only x, I could do: jax.jacobian(f, argnums=0)(x, y)
But suppose the dummy function instead takes in a tuple (x, y) and I wanted to differentiate only the x in that tuple, can I do that using jax?
In qchem we sort of have this situation going here since our molecule information is all stored in the Molecule class, rather than as separate parameters in the function signature. So it's sort of a similar situation to the above example.
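For reference, JAX's behavior on this point: argnums only selects whole positional arguments, and a tuple argument is treated as a single pytree, so the gradient comes back as a matching tuple and the x component is recovered by indexing. A toy sketch (f, x, y are illustrative names, not from the codebase):

```python
import jax
import jax.numpy as jnp

def f(params):
    # params is a tuple (x, y) -- a single pytree argument
    x, y = params
    return jnp.sum(x**2) + jnp.sum(x * y)

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])

# The gradient is computed for the entire tuple; grads is itself
# a tuple (dx, dy) with the same pytree structure as params.
grads = jax.grad(f)((x, y))
dx = grads[0]  # gradient w.r.t. x only: 2*x + y
```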
In qchem we sort of have this situation going here since our molecule information is all stored in the Molecule class, rather than as separate parameters in the function signature. So it's sort of a similar situation to the above example.
Yep, I think this might be a sign we need to revisit this design 🤔 (cc @soranjh)