
[BUG] Passing observables as parameters triggers an exception

Status: Open. erick-xanadu opened this issue 2 years ago · 2 comments

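import pennylane as qml
import pytest
from catalyst import qjit


def test_observable_as_parameter(backend):
    """Test to see if we can pass an observable parameter to a qfunc."""

    coeffs0 = [0.3, -5.1]
    H0 = qml.Hamiltonian(qml.math.array(coeffs0), [qml.PauliZ(0), qml.PauliY(1)])

    # The measurement must live inside a QNode; `backend` is the device fixture.
    @qjit
    @qml.qnode(qml.device(backend, wires=2))
    def circuit(obs):
        return qml.expval(obs)

    # Raises TypeError("Unsupported argument type: ...") from get_runtime_signature.
    circuit(H0)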

Passing H0 as an argument causes line 254 of compilation_pipelines.py to fail. I suspect this is related to how pytrees are handled.

237     @staticmethod
238     def get_runtime_signature(*args):
239         """Get signature from arguments.
240 
241         Args:
242             *args: arguments to the compiled function
243 
244         Returns:
245             a list of JAX shaped arrays
246         """
247         args_data, args_shape = tree_flatten(args)
248 
249         try:
250             r_sig = []
251             for arg in args_data:
252                 r_sig.append(jax.api_util.shaped_abstractify(arg))
253             # Unflatten JAX abstracted args to preserve the shape
254             return tree_unflatten(args_shape, r_sig)
255         except Exception as exc:
256             arg_type = type(arg)
257             raise TypeError(f"Unsupported argument type: {arg_type}") from exc

This is the exception that is raised (and then immediately caught by the except clause on line 255):

TypeError: float() argument must be a string or a real number, not 'ShapedArray'

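This happens because the Hamiltonian unflatten function tries to construct a new Hamiltonian object whose coefficients are ShapedArrays. Somewhere along that construction path float() is called on a coefficient, and a ShapedArray cannot be converted to a Python float.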

erick-xanadu, Oct 26 '23 15:10
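The failure mode can be reproduced with plain JAX, without Catalyst or PennyLane. The sketch below is a minimal illustration built on two assumptions. First, it uses a hypothetical `Weighted` class (not part of PennyLane) whose constructor calls `float()` on its coefficient, which mimics what the error message suggests happens inside the Hamiltonian constructor. Second, it uses `jax.ShapeDtypeStruct` as a stand-in for `ShapedArray`, since neither type can be converted by `float()`. The helper `abstract_signature` is also hypothetical. It mirrors the flatten, abstractify, unflatten pattern in `get_runtime_signature`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class Weighted:
    """Hypothetical pytree mimicking an observable with a numeric coefficient."""

    def __init__(self, coeff, label):
        # Coerce to a Python float: works for concrete numbers only.
        self.coeff = float(coeff)
        self.label = label

    def tree_flatten(self):
        # The coefficient is a leaf; the label is static auxiliary data.
        return (self.coeff,), self.label

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilding goes through __init__, so float() sees whatever leaf is passed in.
        return cls(children[0], aux_data)


def abstract_signature(*args):
    """Hypothetical mirror of the reported pattern: abstractify leaves, then unflatten."""
    leaves, treedef = tree_flatten(args)
    abstract = [jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x)) for x in leaves]
    # This calls Weighted.tree_unflatten with abstract leaves and fails in float().
    return tree_unflatten(treedef, abstract)


try:
    abstract_signature(Weighted(0.3, "Z0"))
except TypeError as exc:
    print(f"TypeError: {exc}")
```

Concrete values round-trip through `Weighted` without trouble. Only the abstract leaves trigger the `TypeError`, which matches the behaviour reported for the Hamiltonian.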

Relevant discussion in JAX: https://github.com/google/jax/discussions/18291

erick-xanadu, Oct 26 '23 19:10

Regardless of whether we fix this upstream by modifying the Hamiltonian class constructor, a safe resolution on our side would be to stop calling tree_unflatten on ShapedArray leaves, as our code currently does. This would also avoid similar issues with other custom PyTree classes outside our control.

dime10, Nov 08 '23 21:11
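One way to avoid unflattening with abstract leaves is to keep the signature flat and return the `PyTreeDef` next to it. This way, no user-defined constructor ever sees an abstract value. The sketch below makes several assumptions:

- Downstream code only needs the abstract leaves plus the tree structure.
- `flat_runtime_signature` and the `Weighted` class are hypothetical names for illustration, not Catalyst or PennyLane APIs.
- `jax.ShapeDtypeStruct` again stands in for `ShapedArray`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class Weighted:
    """Hypothetical pytree whose constructor only accepts concrete numbers."""

    def __init__(self, coeff, label):
        self.coeff = float(coeff)
        self.label = label

    def tree_flatten(self):
        return (self.coeff,), self.label

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


def flat_runtime_signature(*args):
    """Hypothetical: return abstract leaves plus the tree structure, without unflattening."""
    leaves, treedef = tree_flatten(args)
    flat_sig = [jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x)) for x in leaves]
    # Weighted.__init__ is never called on abstract values here.
    return flat_sig, treedef


sig, treedef = flat_runtime_signature(Weighted(0.3, "Z0"), jnp.ones((2, 3)))
print(sig)
# The same treedef can still rebuild the user's objects from concrete values later.
restored = tree_unflatten(treedef, [0.5, jnp.zeros((2, 3))])
print(restored[0].coeff)
```

Signatures can then be compared as `(flat_sig, treedef)` pairs. Whether Catalyst's compiled-function machinery can consume a flat signature like this is an assumption; the thread does not describe how the fix was implemented.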