[Frontend] Remove special handling for Hamiltonian primitives

Open erick-xanadu opened this issue 2 years ago • 12 comments

Once this issue in PennyLane (PL) is resolved, apply the following patch:

diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index 0205381..f7bddb6 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -352,24 +352,6 @@ def trace_quantum_tape(
     return out, qreg, qubit_states
 
 
-# TODO: remove once fixed upstream
-def trace_hamiltonian(coeffs, *nested_obs):
-    """Trace a hamiltonian.
-
-    Args:
-        coeffs: a list of coefficients
-        nested_obs: a list of the nested observables
-
-    Returns:
-        a hamiltonian JAX primitive used for tracing
-    """
-    # jprim.hamiltonian cannot take a python list as input
-    # only as *args can a list be passed as an input.
-    # Instead cast it as a JAX array.
-    coeffs = jax.numpy.asarray(coeffs)
-    return jprim.hamiltonian(coeffs, *nested_obs)
-
-
 def trace_observables(obs, qubit_states, p, num_wires, qreg):
     """Trace observables.
 
@@ -405,7 +387,7 @@ def trace_observables(obs, qubit_states, p, num_wires, qreg):
         jax_obs = jprim.tensorobs(*nested_obs)
     elif isinstance(obs, qml.Hamiltonian):
         nested_obs = [trace_observables(o, qubit_states, p, num_wires, qreg)[0] for o in obs.ops]
-        jax_obs = trace_hamiltonian(op_args, *nested_obs)
+        jax_obs = jprim.hamiltonian(op_args, *nested_obs)
     else:
         raise RuntimeError(f"unknown observable in measurement process: {obs}")
     return jax_obs, qubits

erick-xanadu, Jun 16 '23 17:06

@dime10, it looks like making the proposed change in PL is going to be more difficult than expected. I'll leave this issue open for discussion, but I propose we just delete the TODO comment and keep the code as is.

erick-xanadu, Jun 20 '23 18:06

We should still think about resolving it, though. If the fix we have is a problem for them, it might also be a problem for us, and it definitely will become one as we move forward with the integration.

dime10, Jun 20 '23 19:06

@dime10, I'll add tests for the cases that Matthew mentioned and see if there are any errors.

erick-xanadu, Jun 20 '23 19:06

Does the problem exist with jax.jit as well in any form?

dime10, Jun 20 '23 19:06

I don't see why it would be a problem with jax.jit. The problem concerns how JAX primitives work (they don't accept Python lists as inputs) and the fact that we need to pass this Python list to our Hamiltonian primitive. Perhaps PL goes through a different lowering process? But I might be wrong. I'll give it a try and see if I can trigger a bug with jax.jit when I add the test cases above.
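To illustrate the point about lists, here is a minimal sketch using only public JAX APIs (Catalyst's jprim.hamiltonian is not reproduced). A Python list is a pytree container, so JAX flattens it into one scalar input per element; casting it with jnp.asarray instead yields a single 1-D array, which is the kind of operand a tensor-typed primitive expects. The helper coeffs_as_array is hypothetical and only mirrors the jax.numpy.asarray call in trace_hamiltonian.

```python
import jax
import jax.numpy as jnp


def coeffs_as_array(coeffs):
    # Hypothetical helper mirroring the cast in trace_hamiltonian:
    # pack a Python list of coefficients (constants or tracers)
    # into a single 1-D JAX array.
    return jnp.asarray(coeffs)


# The list argument is flattened into one traced input per element...
closed = jax.make_jaxpr(coeffs_as_array)([0.2, -0.53])
print(len(closed.jaxpr.invars))   # 2
# ...but the cast produces one array value.
print(closed.out_avals[0].shape)  # (2,)


# The same cast also works on tracers inside jax.jit.
@jax.jit
def packed(a, b):
    return coeffs_as_array([a, b])


print(packed(0.2, -0.53).shape)   # (2,)
```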

erick-xanadu, Jun 20 '23 20:06

I was not able to reproduce the issue with jax.jit alone. I used the following test, which varies how the user originally specifies the coefficients:

import jax
import numpy as np
import pennylane as qml
import pytest
from jax import numpy as jnp


@pytest.mark.parametrize(
    "h_coeffs",
    [
        [0.2, -0.53],
        np.array([0.2, -0.53]),
        jnp.array([0.2, -0.53]),
        [np.array([0.2]), np.array([-0.53])],
        [jnp.array([0.2]), jnp.array([-0.53])],
    ],
)
@pytest.mark.parametrize("inp", [[1.0, 2.0]])
def test_jax_consts(inp, h_coeffs):
    """Test Hamiltonian coefficients given in different forms under plain jax.jit."""

    @jax.jit
    @qml.qnode(qml.device("default.qubit", wires=3))
    def circuit(params):
        qml.CRX(params[0], wires=[0, 1])
        qml.CRX(params[0], wires=[0, 2])
        h_obs = [qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.Hadamard(2)]
        return qml.expval(qml.Hamiltonian(h_coeffs, h_obs))

    # The test passes if tracing and execution raise no error.
    circuit(jnp.array(inp))

Regarding Catalyst (rather than PL), I added the cases [np.array([0.2]), np.array([-0.53])] and [jnp.array([0.2]), jnp.array([-0.53])] to the test that checks Hamiltonian coefficients, and we handle these cases correctly as well. I don't yet see how keeping the current cast to a JAX NumPy array is problematic. Can you explain what you see as the issue?

erick-xanadu, Jun 20 '23 20:06

Just a note that Hamiltonian is marked for deprecation in PL, and is due to be replaced with https://docs.pennylane.ai/en/stable/code/api/pennylane.sum.html.

Does the issue still occur with qml.sum? If not, this may be a bug fix we don't want to prioritize.

josh146, Jun 20 '23 20:06

Thanks, @josh146. We definitely need to add tests for qml.sum, then!

erick-xanadu, Jun 20 '23 21:06

Thanks, @erick-xanadu! I think we can leave this as is for now. We'll be adding support for qml.sum next quarter as well!

dime10, Jun 20 '23 23:06

@dime10, I'm not entirely sure whether you still want the test cases. I opened PR #171, but if you still have concerns about how this is handled, feel free to close the PR and we can continue the discussion in this ticket.

erick-xanadu, Jun 21 '23 13:06

EDIT: Never mind, the test is wrong. For some of these I am creating a list of lists.

~~@dime10, I went back, thought a bit more about how to test this, and found the following issue:~~

import numpy as np
import pennylane as qml
import pytest
from jax import numpy as jnp

from catalyst import grad, qjit


@pytest.mark.parametrize(
    "h_coeffs",
    [
        [0.2],
        np.array([0.2]),
        jnp.array([0.2]),
        [np.array([0.2])],
        [jnp.array([0.2])],
    ],
)
@pytest.mark.parametrize("inp", [[1.0, 2.0]])
def test_hamiltonian_with_computation(inp, h_coeffs, backend):
    """Test a Hamiltonian whose coefficients are computed inside the circuit."""

    def circuit(params):
        qml.CRX(params[0], wires=[0, 1])
        qml.CRX(params[0], wires=[0, 2])
        # Build a new coefficient list from the first user-supplied coefficient.
        h_coeffs_comp = [h_coeffs[0] + 1]
        h_obs = [qml.PauliX(0) @ qml.PauliZ(1)]
        return qml.expval(qml.Hamiltonian(h_coeffs_comp, h_obs))

    @qjit(keep_intermediate=True)  # keep the intermediate MLIR files for inspection
    def compile_grad(params):
        g = qml.qnode(qml.device(backend, wires=3))(circuit)
        h = grad(g, method="fd")
        return h(params)

    print(compile_grad(jnp.array(inp)))

Results for each value of h_coeffs:

| h_coeffs | Result |
| --- | --- |
| `[0.2]` | success |
| `np.array([0.2])` | success |
| `jnp.array([0.2])` | `The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=4/0)>` |
| `[np.array([0.2])]` | See the MLIR error below |
| `[jnp.array([0.2])]` | `The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=4/0)>` |

Error for `[np.array([0.2])]`:
/home/erick/code/catalyst-latest/frontend/test/pytest/compile_grad/compile_grad.nohlo.mlir:27:10: error: 'quantum.hamiltonian' op operand #0 must be 1D tensor of 64-bit float values or 1D memref of 64-bit float values, but got 'tensor<1x1xf64>'
    %9 = "quantum.hamiltonian"(%arg0, %8) : (tensor<1x1xf64>, !quantum.obs) -> !quantum.obs
         ^
/home/erick/code/catalyst-latest/frontend/test/pytest/compile_grad/compile_grad.nohlo.mlir:27:10: note: see current operation: %18 = "quantum.hamiltonian"(%arg0, %17) : (tensor<1x1xf64>, !quantum.obs) -> !quantum.obs

erick-xanadu, Jun 21 '23 14:06

@dime10, I strongly believe there is no actual bug behind this ticket. The error was introduced in my original test by changing the type of the input: instead of a vector, I passed a matrix by mistake.
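A minimal sketch of how the matrix sneaks in, assuming the computed coefficient list reaches the same jax.numpy.asarray cast that trace_hamiltonian performs. The helper computed_coeffs is hypothetical; it only reproduces the expression [h_coeffs[0] + 1] from the test, outside of qjit.

```python
import numpy as np
import jax.numpy as jnp


def computed_coeffs(h_coeffs):
    # Hypothetical helper: same expression as the test,
    # a new one-element list built from h_coeffs[0].
    return [h_coeffs[0] + 1]


# h_coeffs[0] is a scalar, so the list casts to a 1-D vector.
vector_case = jnp.asarray(computed_coeffs(np.array([0.2])))
# h_coeffs[0] is itself a one-element array, so the list casts to a 1x1 matrix.
matrix_case = jnp.asarray(computed_coeffs([np.array([0.2])]))

print(vector_case.shape)  # (1,)
print(matrix_case.shape)  # (1, 1)
```

The (1, 1) shape in the second case is consistent with the tensor<1x1xf64> operand rejected by quantum.hamiltonian in the MLIR error above.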

I then also looked into how adding operations could affect the Hamiltonian's JAXPR and MLIR. However, I had forgotten that JAX arrays sometimes need a different interface (e.g., .at[0].add(1)).
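For reference, a short sketch of that interface difference: NumPy arrays accept in-place item updates, while JAX arrays are immutable and use the functional .at[...] API, which returns a new array. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays are mutable, so an in-place update works.
np_coeffs = np.array([0.2])
np_coeffs[0] += 1

# JAX arrays are immutable: item assignment raises a TypeError.
jax_coeffs = jnp.array([0.2])
rejected = False
try:
    jax_coeffs[0] += 1
except TypeError:
    rejected = True

# The functional .at[...] interface returns a new array instead;
# jax_coeffs itself is left unchanged.
updated = jax_coeffs.at[0].add(1)

print(np_coeffs, rejected, updated)
```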

erick-xanadu, Jun 22 '23 15:06

We no longer plan to remove the conversion of coefficients to an array; in fact, we now use it in a few places (see the note in #171). The special function for Hamiltonians, however, has been removed.

dime10, Jul 12 '24 16:07