
Make qchem jax compatible

Open austingmhuang opened this issue 1 year ago • 8 comments

Context: With Autograd being deprecated, we are moving to JAX for auto-differentiation. Since the requires_grad keyword is not supported by JAX arrays, we need a different solution.

Description of the Change: We keep backwards compatibility with pnp by checking which interface the user is using, inferred from the interface of the input tensors. If they are using autograd, we stick to the old workflow and check requires_grad using getattr().

If the user passes a JAX array for any of [coordinates, coefficients, alpha], we assume they want to use JAX and define all undefined coefficients/alphas as JAX arrays. If a user mixes pnp with JAX, we cannot decide which interface to cast to, so we do not hard-cast the remaining parameters into either; instead we issue a warning about mixing the two interfaces.
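A rough sketch of the interface-detection logic described above (the helper name and warning text are hypothetical, for illustration only):

    import warnings
    import pennylane as qml

    def _resolve_interface(*tensors):
        # Infer the working interface from the input tensors (illustrative helper).
        interfaces = {qml.math.get_interface(t) for t in tensors if t is not None}
        if "jax" in interfaces and "autograd" in interfaces:
            warnings.warn(
                "Mixing autograd (pnp) and JAX arrays; parameters will not be "
                "cast to a common interface."
            )
        return "jax" if "jax" in interfaces else "autograd"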

WHEN USING JAX: If users wish to differentiate any of these parameters, they should mark the parameter(s) they want differentiable using the JAX UI, e.g. jax.grad(..., argnums=<index or indices of the differentiable parameter(s)>)(*args). In our case, due to technical limitations, *args must always be exactly [coordinates, coefficients, alpha]; no other order is allowed and none of them can be omitted. Unfortunately, this also applies when you are NOT using jax.grad or another JAX transform: when you call diff_hamiltonian(...)(*args), the args (if using JAX) also need to be exactly [coordinates, coefficients, alpha]. When you do decide to differentiate, jax.grad(..., argnums=1)(coordinates, coefficients, alpha) means you want coefficients to be differentiable. Note this is a departure from the UI of qml.grad, where you could do qml.grad(..., argnum=0)(coefficients) instead.
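A minimal sketch of this workflow (the molecule, the scalar cost function, and the reduction via h.terms()[0][0] are purely illustrative; it assumes Molecule exposes coeff and alpha as in the existing pnp workflow):

    import jax
    import jax.numpy as jnp
    import pennylane as qml

    jax.config.update("jax_enable_x64", True)  # qchem generally needs 64-bit precision

    symbols = ["H", "H"]
    geometry = jnp.array([[0.0, 0.0, -0.672], [0.0, 0.0, 0.672]])
    mol = qml.qchem.Molecule(symbols, geometry)

    def cost(coordinates, coefficients, alpha):
        # *args must be exactly [coordinates, coefficients, alpha]
        h = qml.qchem.diff_hamiltonian(mol)(coordinates, coefficients, alpha)
        # reduce to a scalar so jax.grad applies; picking the first
        # Hamiltonian coefficient is for illustration only
        return h.terms()[0][0]

    # argnums=0 marks `coordinates` as the differentiable argument
    grad_coords = jax.grad(cost, argnums=0)(geometry, mol.coeff, mol.alpha)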

Additional notes: The UI for qml.grad and everything else is unchanged for autograd/pnp users. However, if you are using JAX and passing the args keyword to molecular_hamiltonian and related Hamiltonian constructors, you will also need to supply all of [coordinates, coefficients, alpha], since args is passed downstream to diff_hamiltonian(...)(*args).
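Continuing the sketch above, that would look roughly like this (hedged; the call is abbreviated and illustrative):

    # With JAX, `args` must contain all three parameter arrays
    H, qubits = qml.qchem.molecular_hamiltonian(
        symbols,
        geometry,
        args=[geometry, mol.coeff, mol.alpha],
    )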

Benefits: The qchem module is now JAX compatible.

Possible Drawbacks: More changes may be needed to support JIT, and there may be performance issues. The UI differs between qml.grad and jax.grad. Expectations for the args keyword differ between JAX arrays and pnp arrays.

Related GitHub Issues: [sc-69776] [sc-69778]

austingmhuang avatar Aug 13 '24 16:08 austingmhuang

Hello. You may have forgotten to update the changelog! Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

github-actions[bot] avatar Aug 13 '24 16:08 github-actions[bot]

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 99.64%. Comparing base (c5fd5bc) to head (705395a). Report is 285 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #6096      +/-   ##
==========================================
- Coverage   99.64%   99.64%   -0.01%     
==========================================
  Files         469      468       -1     
  Lines       44331    44065     -266     
==========================================
- Hits        44173    43907     -266     
  Misses        158      158              

:umbrella: View full report in Codecov by Sentry.

codecov[bot] avatar Aug 14 '24 16:08 codecov[bot]

I think we can add the warning/error for 64-bit precision via JAX in this PR itself.

Let's make this a separate PR. We'll have to add test cases as well and I don't want this PR to be that big.

austingmhuang avatar Sep 03 '24 13:09 austingmhuang

Hey @austingmhuang, just wondering here how come you are adding argnum to the various qchem functionalities?

I would expect that we simply need to change the logic in the qchem functions to use qml.math.requires_grad(array), which should support all interfaces:

[screenshot: code example showing qml.math.requires_grad inferring differentiability from the inputs across interfaces]

That is, the user only passes argnum to jax.grad or jax.jacobian, and internally code uses qml.math.requires_grad
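Roughly, the pattern is something like this (the printed output is what I'd expect; the exact behaviour of qml.math.requires_grad with JAX tracers is an assumption here, not taken from the screenshot):

    import jax
    import pennylane as qml
    from pennylane import numpy as pnp

    def f(x, y):
        # Inside jax.grad, only the argument selected by argnums is a JAX
        # tracer, so differentiability can be inferred per input.
        print(qml.math.requires_grad(x), qml.math.requires_grad(y))
        return x * y

    jax.grad(f, argnums=0)(0.5, 0.3)   # expected to print: True False

    # With autograd/pnp, the flag lives on the array itself:
    a = pnp.array(0.5, requires_grad=True)
    b = pnp.array(0.3, requires_grad=False)
    print(qml.math.requires_grad(a), qml.math.requires_grad(b))  # True False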

josh146 avatar Sep 04 '24 14:09 josh146

Hey @austingmhuang, just wondering here how come you are adding argnum to the various qchem functionalities?

For functions like overlap_matrix, we need to know what's differentiable in order to know what to do. So functionally, doing qml.qchem.overlap_matrix() will need that information somehow.

I would expect that we simply need to change the logic in the qchem functions to use qml.math.requires_grad(array), which should support all interfaces:

Uhm, I'm a bit confused here. In this situation, everything would be non-differentiable and we would just have a function that sets the differentiability based on whatever is set in argnums from some jax function. For something like qml.qchem.diff_hamiltonian(), how would I specify that I want to differentiate specifically the geometry, for example? I'm probably missing something (I'm not that well-versed with jax), but it's not clear to me how I could do that with the example you have there 😅

jax.jacobian(qml.qchem.diff_hamiltonian(mol, cutoff, core, active), argnums=???)(?)

austingmhuang avatar Sep 04 '24 16:09 austingmhuang

For functions like overlap_matrix, we need to know what's differentiable in order to know what to do. So functionally, doing qml.qchem.overlap_matrix() will need that information somehow.

Yep, and we should be able to do this via qml.math.requires_grad, so that it is automatically inferred from the parameters :)

For something like qml.qchem.diff_hamiltonian(), how would I specify that I want to differentiate specifically the geometry, for example?

The typical approach is that, if a user is calling a workflow via jax.grad(..., argnums=...)(*args), then it's the jax.grad function (via argnums) that determines which parameters are differentiable! So the user would not be specifying argnums on internal functions within the workflow, this should be extracted automatically via qml.math.requires_grad.

josh146 avatar Sep 04 '24 18:09 josh146

The typical approach is that, if a user is calling a workflow via jax.grad(..., argnums=...)(*args), then it's the jax.grad function (via argnums) that determines which parameters are differentiable! So the user would not be specifying argnums on internal functions within the workflow, this should be extracted automatically via qml.math.requires_grad.

So in your example, you have a dummy function f that has two variables x and y, and you can select which one you want to differentiate by specifying an index, 0 for x, 1 for y. If I wanted to differentiate only x, I could do: jax.jacobian(f, argnums=0)(x, y)

But suppose the dummy function instead takes in a tuple (x, y) and I wanted to differentiate only the x in that tuple, can I do that using jax?

In qchem we sort of have this situation, since our molecule information is all stored in the Molecule class rather than as separate parameters in the function signature. So it's a similar situation to the example above.
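For concreteness, a toy version of what I mean (with plain JAX it seems you need a wrapper that exposes the element you want as its own positional argument):

    import jax
    import jax.numpy as jnp

    def f(xy):                  # takes a single tuple (x, y)
        x, y = xy
        return jnp.sum(x ** 2) + jnp.sum(y)

    x = jnp.array([1.0, 2.0])
    y = jnp.array([3.0])

    # jax.grad(f)((x, y)) would differentiate the whole tuple (pytree).
    # argnums only selects whole positional arguments, so to differentiate
    # only x you seem to need a wrapper:
    grad_x = jax.grad(lambda x_: f((x_, y)))(x)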

austingmhuang avatar Sep 05 '24 13:09 austingmhuang

In qchem we sort of have this situation going here since our molecule information is all stored in the Molecule class, rather than as separate parameters in the function signature. So it's sort of a similar situation to the above example.

Yep, I think this might be a sign we need to revisit this design 🤔 (cc @soranjh)

josh146 avatar Sep 05 '24 22:09 josh146