catalyst
[BUG] Passing observables as parameters triggers an exception
```python
import pennylane as qml
import pytest

from catalyst import qjit


def test_observable_as_parameter(backend):
    """Test to see if we can pass an observable parameter to qfunc."""
    coeffs0 = [0.3, -5.1]
    H0 = qml.Hamiltonian(qml.math.array(coeffs0), [qml.PauliZ(0), qml.PauliY(1)])

    @qjit
    @qml.qnode(qml.device(backend, wires=2))  # qml.expval must be returned from a QNode
    def circuit(obs):
        return qml.expval(obs)

    # Fails while qjit builds the runtime signature from the arguments
    circuit(H0)
```
Passing H0 as an argument makes the tree_unflatten call on line 254 of compilation_pipelines.py fail.
I suspect this is related to pytrees, since Hamiltonian is flattened and unflattened as a JAX pytree there.
```
237     @staticmethod
238     def get_runtime_signature(*args):
239         """Get signature from arguments.
240 
241         Args:
242             *args: arguments to the compiled function
243 
244         Returns:
245             a list of JAX shaped arrays
246         """
247         args_data, args_shape = tree_flatten(args)
248 
249         try:
250             r_sig = []
251             for arg in args_data:
252                 r_sig.append(jax.api_util.shaped_abstractify(arg))
253             # Unflatten JAX abstracted args to preserve the shape
254             return tree_unflatten(args_shape, r_sig)
255         except Exception as exc:
256             arg_type = type(arg)
257             raise TypeError(f"Unsupported argument type: {arg_type}") from exc
```
This is the exception raised inside the try block (before it is caught by the except clause on line 255):

```
TypeError: float() argument must be a string or a real number, not 'ShapedArray'
```
This happens because the Hamiltonian unflatten function tries to construct a new Hamiltonian object from ShapedArray leaves, and somewhere in that construction a coefficient is passed to float().
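The failure mode can be reproduced without PennyLane or Catalyst. The sketch below is a minimal illustration under stated assumptions: `ToyObservable` is a hypothetical pytree class (not PennyLane code) whose constructor calls `float()` on its leaf, mimicking the failing Hamiltonian constructor, and `jax.ShapeDtypeStruct` stands in for the `ShapedArray` that `shaped_abstractify` produces. Only `jax.tree_util` is used for flattening and unflattening.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


@register_pytree_node_class
class ToyObservable:
    """Hypothetical Hamiltonian-like pytree (illustration only, not PennyLane code)."""

    def __init__(self, coeff, wire):
        # Eagerly converts the leaf to a Python float, like the failing constructor.
        self.coeff = float(coeff)
        self.wire = wire

    def tree_flatten(self):
        # coeff is the only leaf; wire is static auxiliary data.
        return (self.coeff,), self.wire

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilding goes through __init__, so float() sees whatever the leaf now is.
        return cls(*children, aux_data)


def unflatten_with_abstract_leaves(tree):
    """Mimic get_runtime_signature: abstract every leaf, then unflatten."""
    leaves, treedef = tree_flatten(tree)
    # ShapeDtypeStruct is a stand-in for the ShapedArray returned by shaped_abstractify.
    abstract_leaves = [jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x)) for x in leaves]
    return tree_unflatten(treedef, abstract_leaves)


# Built-in containers rebuild fine, because tuples never inspect their leaves.
print(unflatten_with_abstract_leaves((1.0, jnp.ones(3))))

# The custom class fails inside its own constructor.
try:
    unflatten_with_abstract_leaves(ToyObservable(0.3, wire=0))
except TypeError as exc:
    print(f"TypeError: {exc}")
```

The point of the contrast is that `tree_unflatten` itself is not at fault: it simply hands the abstract leaves to the class's own unflatten logic, so any constructor that validates or converts its inputs will break.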
Relevant discussion in JAX: https://github.com/google/jax/discussions/18291
Regardless of whether we fix this upstream by modifying the Hamiltonian constructor, a safe resolution on our side would be to stop calling tree_unflatten on ShapedArrays, as get_runtime_signature currently does.
This would also avoid similar issues with other custom PyTree classes outside our control.
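One possible shape for that resolution, sketched under assumptions rather than taken from Catalyst: the hypothetical `get_runtime_signature_flat` below keeps the abstracted leaves flat and returns them together with the treedef, so no user-defined constructor is ever called on an abstract value. Two signatures can then be compared by treedef and leaf by leaf. `ToyObservable` is a hypothetical Hamiltonian-like pytree, and `jax.ShapeDtypeStruct` stands in for `ShapedArray`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten


@register_pytree_node_class
class ToyObservable:
    """Hypothetical Hamiltonian-like pytree whose constructor rejects abstract leaves."""

    def __init__(self, coeff, wire):
        self.coeff = float(coeff)
        self.wire = wire

    def tree_flatten(self):
        return (self.coeff,), self.wire

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, aux_data)


def get_runtime_signature_flat(*args):
    """Hypothetical variant of get_runtime_signature that never calls tree_unflatten.

    Returns (abstract_leaves, treedef); the tree structure is kept in the treedef
    instead of being rebuilt around abstract values.
    """
    leaves, treedef = tree_flatten(args)
    abstract_leaves = []
    for leaf in leaves:
        try:
            abstract_leaves.append(jax.ShapeDtypeStruct(jnp.shape(leaf), jnp.result_type(leaf)))
        except Exception as exc:
            # Report the leaf that actually failed, not the last loop variable.
            raise TypeError(f"Unsupported argument type: {type(leaf)}") from exc
    return abstract_leaves, treedef


def signatures_match(sig_a, sig_b):
    """Calls are compatible if tree structure (incl. static data) and all leaves agree."""
    leaves_a, treedef_a = sig_a
    leaves_b, treedef_b = sig_b
    return treedef_a == treedef_b and leaves_a == leaves_b


sig_1 = get_runtime_signature_flat(ToyObservable(0.3, wire=0))
sig_2 = get_runtime_signature_flat(ToyObservable(-5.1, wire=0))
sig_3 = get_runtime_signature_flat(ToyObservable(-5.1, wire=1))
print(signatures_match(sig_1, sig_2))  # True: same structure, same leaf shape/dtype
print(signatures_match(sig_1, sig_3))  # False: the static wire differs
```

Because the treedef already encodes the static parts of each argument, nothing is lost by skipping the unflatten step; the cost is that downstream code must work with flat leaves plus a treedef instead of a nested structure of ShapedArrays.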