catalyst
[Frontend] Remove special handling for Hamiltonian primitives
Once this issue in PL is resolved, apply the following patch:
```diff
diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index 0205381..f7bddb6 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -352,24 +352,6 @@ def trace_quantum_tape(
     return out, qreg, qubit_states
 
 
-# TODO: remove once fixed upstream
-def trace_hamiltonian(coeffs, *nested_obs):
-    """Trace a hamiltonian.
-
-    Args:
-        coeffs: a list of coefficients
-        nested_obs: a list of the nested observables
-
-    Returns:
-        a hamiltonian JAX primitive used for tracing
-    """
-    # jprim.hamiltonian cannot take a python list as input
-    # only as *args can a list be passed as an input.
-    # Instead cast it as a JAX array.
-    coeffs = jax.numpy.asarray(coeffs)
-    return jprim.hamiltonian(coeffs, *nested_obs)
-
-
 def trace_observables(obs, qubit_states, p, num_wires, qreg):
     """Trace observables.
 
@@ -405,7 +387,7 @@ def trace_observables(obs, qubit_states, p, num_wires, qreg):
         jax_obs = jprim.tensorobs(*nested_obs)
     elif isinstance(obs, qml.Hamiltonian):
         nested_obs = [trace_observables(o, qubit_states, p, num_wires, qreg)[0] for o in obs.ops]
-        jax_obs = trace_hamiltonian(op_args, *nested_obs)
+        jax_obs = jprim.hamiltonian(op_args, *nested_obs)
     else:
         raise RuntimeError(f"unknown observable in measurement process: {obs}")
     return jax_obs, qubits
```
@dime10 it looks like making the proposed change in PL will be more difficult than expected. I'll leave this issue open for discussion, but I propose we simply delete the TODO comment and keep the code as is.
We should still think about resolving it, though: if our fix is a problem for them, it might also be a problem for us, and it will definitely become one as we move forward with the integration.
@dime10, I'll add tests for the cases that Matthew mentioned and see if there are any errors.
Does the problem exist with jax.jit as well in any form?
I don't see why it would be a problem with `jax.jit`. The problem concerns how JAX primitives work (i.e., they don't take Python lists as inputs) and the fact that we need to pass this Python list to our Hamiltonian primitive. Perhaps PL goes through a different lowering process? But I might be wrong. I'll try to trigger a bug with `jax.jit` when I add the test cases mentioned above.
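A minimal sketch of the list-vs-array behaviour described above, using only public JAX APIs. It does not touch Catalyst's `jprim.hamiltonian`; the `cast` helper is hypothetical and only mirrors what the removed `trace_hamiltonian` did. JAX treats a Python list as a pytree container, so tracing splits it into one input per element, whereas `jnp.asarray` stacks it into a single 1-D array operand.

```python
import jax
import jax.numpy as jnp

coeffs = [0.2, -0.53]  # coefficients given as a plain Python list

# JAX treats a list as a pytree container, not as an array: each element
# becomes a separate leaf (and, when traced, a separate input variable).
print(len(jax.tree_util.tree_leaves(coeffs)))  # 2


def cast(c):
    # Hypothetical helper mirroring the removed trace_hamiltonian: stack the
    # list into a single array before it would be handed to a primitive.
    return jnp.asarray(c)


closed = jax.make_jaxpr(cast)(coeffs)
print(len(closed.jaxpr.invars))   # 2: one scalar input per list element
print(closed.out_avals[0].shape)  # (2,): a single 1-D array
```

This is why the original helper cast the coefficients instead of passing the list through: a primitive receives flat array operands, and a list only fits if it is spread out as `*args`.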
I was not able to replicate the issue with `jax.jit` alone. I used the following test, which varies how the coefficients are originally specified by the user:
```python
import jax
import numpy as np
import pennylane as qml
import pytest
from jax import numpy as jnp

import catalyst.utils.calculate_grad_shape as infer
from catalyst import CompileError, cond, for_loop, grad, qjit
from catalyst.pennylane_extensions import DifferentiableCompileError


@pytest.mark.parametrize("g_method", ["fd", "defer"])
@pytest.mark.parametrize(
    "h_coeffs",
    [
        [0.2, -0.53],
        np.array([0.2, -0.53]),
        jnp.array([0.2, -0.53]),
        [np.array([0.2]), np.array([-0.53])],
        [jnp.array([0.2]), jnp.array([-0.53])],
    ],
)
@pytest.mark.parametrize("inp", [([1.0, 2.0])])
def test_jax_consts(inp, h_coeffs, g_method, backend):
    """Test jax constants."""

    @jax.jit
    @qml.qnode(qml.device("default.qubit", wires=3))
    def circuit(params):
        qml.CRX(params[0], wires=[0, 1])
        qml.CRX(params[0], wires=[0, 2])
        h_obs = [qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.Hadamard(2)]
        return qml.expval(qml.Hamiltonian(h_coeffs, h_obs))

    circuit(jnp.array(inp))
```
Regarding Catalyst (rather than PL), I added the cases `[np.array([0.2]), np.array([-0.53])]` and `[jnp.array([0.2]), jnp.array([-0.53])]` to the test that checks the Hamiltonian coefficients, and we handle these cases correctly as well. I don't yet see how keeping the current cast to a JAX numpy array is problematic; can you explain what you see as the issue?
Just a note that `qml.Hamiltonian` is marked for deprecation in PL and is due to be replaced with `qml.sum`: https://docs.pennylane.ai/en/stable/code/api/pennylane.sum.html.
Does the issue still occur with `qml.sum`? If not, this is potentially a bugfix we may not want to prioritize.
Thanks @josh146. We definitely need to add tests for `qml.sum` then!
Thanks @erick-xanadu ! I think we can leave this as is for now.
We'll be adding support for qml.sum next quarter as well!
@dime10, I'm not entirely sure whether you still want the test cases. I added PR #171, but if you still have concerns about how this is handled, feel free to close the PR and continue the discussion in the ticket.
EDIT: Never mind, the test is wrong. For some of these I am creating a list of lists.
~~@dime10, I went back and thought a bit more about how to test this and found the following issue:~~
```python
import jax
import numpy as np
import pennylane as qml
import pytest
from jax import numpy as jnp

import catalyst.utils.calculate_grad_shape as infer
from catalyst import CompileError, cond, for_loop, grad, qjit


@pytest.mark.parametrize(
    "h_coeffs",
    [
        [0.2],
        np.array([0.2]),
        jnp.array([0.2]),
        [np.array([0.2])],
        [jnp.array([0.2])],
    ],
)
@pytest.mark.parametrize("inp", [([1.0, 2.0])])
def test_hamiltonian_with_computation(inp, h_coeffs, backend):
    """Test jax constants."""

    def circuit(params):
        qml.CRX(params[0], wires=[0, 1])
        qml.CRX(params[0], wires=[0, 2])
        h_coeffs_comp = [h_coeffs[0] + 1]
        h_obs = [qml.PauliX(0) @ qml.PauliZ(1)]
        return qml.expval(qml.Hamiltonian(h_coeffs_comp, h_obs))

    @qjit(keep_intermediate=True)
    def compile_grad(params):
        g = qml.qnode(qml.device(backend, wires=3))(circuit)
        h = grad(g, method="fd")
        return h(params)

    print(compile_grad(jnp.array(inp)))
```
Results for each `h_coeffs` input:

| `h_coeffs` | Result |
|---|---|
| `[0.2]` | Success |
| `np.array([0.2])` | Success |
| `jnp.array([0.2])` | `The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=4/0)>` |
| `[np.array([0.2])]` | See below |
| `[jnp.array([0.2])]` | `The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=4/0)>` |
Error for `[np.array([0.2])]`:

```
/home/erick/code/catalyst-latest/frontend/test/pytest/compile_grad/compile_grad.nohlo.mlir:27:10: error: 'quantum.hamiltonian' op operand #0 must be 1D tensor of 64-bit float values or 1D memref of 64-bit float values, but got 'tensor<1x1xf64>'
    %9 = "quantum.hamiltonian"(%arg0, %8) : (tensor<1x1xf64>, !quantum.obs) -> !quantum.obs
         ^
/home/erick/code/catalyst-latest/frontend/test/pytest/compile_grad/compile_grad.nohlo.mlir:27:10: note: see current operation: %18 = "quantum.hamiltonian"(%arg0, %17) : (tensor<1x1xf64>, !quantum.obs) -> !quantum.obs
```
@dime10, I strongly believe there is no bug behind this ticket. The error was introduced in my original test by changing the type of the input: instead of a vector, I passed a matrix by mistake.
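To make the vector/matrix mix-up concrete, here is a small sketch that reproduces only the shapes with plain NumPy/JAX (no Catalyst involved; `h_coeffs_vec` is a hypothetical name). It assumes the coefficients are meant to be a flat vector, which is what the `quantum.hamiltonian` error requires.

```python
import numpy as np
import jax.numpy as jnp

h_coeffs = [np.array([0.2])]  # one of the parametrized inputs in the test

# h_coeffs[0] is already a shape-(1,) array, so wrapping the sum in a new
# list nests it: the result is a 1x1 matrix rather than a vector.
h_coeffs_comp = [h_coeffs[0] + 1]
print(jnp.asarray(h_coeffs_comp).shape)  # (1, 1), cf. tensor<1x1xf64>

# Flattening first keeps the coefficients 1-D.
h_coeffs_vec = jnp.ravel(jnp.asarray(h_coeffs)) + 1
print(h_coeffs_vec.shape)  # (1,)
```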
I then also looked into how adding operations could affect the Hamiltonian's JAXPR and MLIR. However, I had forgotten that JAX arrays sometimes need a different interface (e.g., `.at[0].add(1)`).
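For reference, a minimal sketch of the interface difference mentioned here, using standard JAX behaviour: JAX arrays are immutable, so NumPy-style in-place updates fail, and the `.at[]` API returns an updated copy instead.

```python
import jax.numpy as jnp

coeffs = jnp.array([0.2])

# JAX arrays are immutable: NumPy-style in-place updates raise a TypeError.
try:
    coeffs[0] += 1
    raised = False
except TypeError:
    raised = True
print(raised)  # True

# The functional .at[] API returns a new array; coeffs itself is unchanged.
updated = coeffs.at[0].add(1)
print(updated.shape)  # (1,)
```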
We no longer plan to remove the conversion of the coefficients to an array; in fact, we now use it in a few places (see the note in #171). The special function for Hamiltonians (`trace_hamiltonian`) has, however, been removed.