pennylane
[BUG] Jitting `Projector` gives `ValueError`
Expected behavior
I expect to be able to run the below code sample without error.
The answer should be the same as the non-jit alternative.
Actual behavior
I get the error included below.
Additional information
No response
Source code
```python
import jax
import numpy as np
import pennylane as qml

@qml.qnode(qml.device('default.qubit'))
def circuit(state):
    return qml.expval(qml.Projector(state, wires=0))

jax.jit(circuit)(np.array([0]))
```
Tracebacks
---------------------------------------------------------------------------
TracerArrayConversionError Traceback (most recent call last)
File ~/Prog/pennylane/pennylane/math/single_dispatch.py:737, in _to_numpy_jax(x)
736 try:
--> 737 return np.array(getattr(x, "val", x))
738 except TracerArrayConversionError as e:
File ~/Prog/pl311/lib/python3.11/site-packages/jax/_src/core.py:710, in Tracer.__array__(self, *args, **kw)
709 def __array__(self, *args, **kw):
--> 710 raise TracerArrayConversionError(self)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[1].
The error occurred while tracing the function circuit at /var/folders/k1/0v_kvphn55lgf_45kntf1hqm0000gq/T/ipykernel_38839/71224146.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument state.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[4], line 7
3 @qml.qnode(qml.device('default.qubit'))
4 def circuit(state):
5 return qml.expval(qml.Projector(state, wires=0))
----> 7 jax.jit(circuit)(np.array([0]))
[... skipping hidden 12 frame]
File ~/Prog/pennylane/pennylane/qnode.py:976, in QNode.__call__(self, *args, **kwargs)
973 kwargs["shots"] = _get_device_shots(self._original_device)
975 # construct the tape
--> 976 self.construct(args, kwargs)
978 cache = self.execute_kwargs.get("cache", False)
979 using_custom_cache = (
980 hasattr(cache, "__getitem__")
981 and hasattr(cache, "__setitem__")
982 and hasattr(cache, "__delitem__")
983 )
File ~/Prog/pennylane/pennylane/qnode.py:862, in QNode.construct(self, args, kwargs)
859 self.interface = qml.math.get_interface(*args, *list(kwargs.values()))
861 with qml.queuing.AnnotatedQueue() as q:
--> 862 self._qfunc_output = self.func(*args, **kwargs)
864 self._tape = QuantumScript.from_queue(q, shots)
866 params = self.tape.get_parameters(trainable_only=False)
Cell In[4], line 5, in circuit(state)
3 @qml.qnode(qml.device('default.qubit'))
4 def circuit(state):
----> 5 return qml.expval(qml.Projector(state, wires=0))
File ~/Prog/pennylane/pennylane/ops/qubit/observables.py:439, in BasisStateProjector.__init__(self, state, wires, id)
437 def __init__(self, state, wires, id=None):
438 wires = Wires(wires)
--> 439 state = list(qml.math.toarray(state).astype(int))
441 if not set(state).issubset({0, 1}):
442 raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")
File ~/Prog/pl311/lib/python3.11/site-packages/autoray/autoray.py:80, in do(fn, like, *args, **kwargs)
31 """Do function named ``fn`` on ``(*args, **kwargs)``, peforming single
32 dispatch to retrieve ``fn`` based on whichever library defines the class of
33 the ``args[0]``, or the ``like`` keyword argument if specified.
(...)
77 <tf.Tensor: id=91, shape=(3, 3), dtype=float32>
78 """
79 backend = choose_backend(fn, *args, like=like, **kwargs)
---> 80 return get_lib_fn(backend, fn)(*args, **kwargs)
File ~/Prog/pennylane/pennylane/math/single_dispatch.py:739, in _to_numpy_jax(x)
737 return np.array(getattr(x, "val", x))
738 except TracerArrayConversionError as e:
--> 739 raise ValueError(
740 "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
741 ) from e
ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
System information
PL master
Existing GitHub issues
- [X] I have searched existing GitHub issues to make sure the issue does not already exist.
Could this be fixed in a similar way to #4966?
This can be solved by bypassing the check that converts the input state to a list.
https://github.com/PennyLaneAI/pennylane/blob/3fd3c0acac6bfacbaa499d0e2bec429a66769e78/pennylane/ops/qubit/observables.py#L439
Honestly, I'm not quite sure why we always convert `basis_state` to a list.
It also needs a couple of other fixes that overlap with #4971.
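To see why the list conversion is the culprit, here is a minimal sketch that reproduces the same root cause with JAX and NumPy alone, without PennyLane. The helper names `validate_as_list` and `keep_as_array` are hypothetical; the first only mimics the `list(qml.math.toarray(state).astype(int))` step in the constructor.

```python
import jax
import jax.numpy as jnp
import numpy as np


def validate_as_list(state):
    # Hypothetical stand-in for the constructor step:
    # convert to NumPy, cast to int, then build a Python list.
    as_list = list(np.asarray(state).astype(int))
    return jnp.asarray(as_list)


def keep_as_array(state):
    # Leaves the input untouched, so tracing never needs a concrete value.
    return state


x = jnp.array([0, 1])

eager = validate_as_list(x)         # works: x has a concrete value outside jit
jitted = jax.jit(keep_as_array)(x)  # works: no NumPy conversion during tracing

try:
    jax.jit(validate_as_list)(x)    # np.asarray() is called on a tracer here
    failed = False
except jax.errors.TracerArrayConversionError:
    failed = True
```

Outside of `jax.jit` both helpers behave the same; only the version that forces a NumPy conversion fails once the input is a tracer.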
Hey @albi3ro, is there anything I can help with here?
@AnuravModak, I've assigned you.
I think 🤞 that we should be able to fix this by:
- Removing the cast to a list linked above. This may cause some tests to break, but I think we're OK with breaking that behavior.
- Only performing the validation check if the state isn't abstract (`qml.math.is_abstract`).
- Potentially updating `compute_eigvals` to be jittable as well.
Feel free to leave any additional things that crop up either here or on a PR.
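The second bullet can be sketched as follows. This is an illustration only: `_is_abstract` is a hypothetical stand-in that treats any JAX tracer as abstract, which is an approximation of `qml.math.is_abstract`, not its actual implementation, and `check_basis_state` is likewise a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _is_abstract(x):
    # Assumption: rough stand-in for qml.math.is_abstract that treats any
    # JAX tracer as having no concrete value available.
    return isinstance(x, jax.core.Tracer)


def check_basis_state(state):
    """Validate 0/1 entries only when the values are concrete."""
    if not _is_abstract(state):
        values = set(np.asarray(state).astype(int).tolist())
        if not values.issubset({0, 1}):
            raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")
    # Return the input unchanged (no cast to list), so tracing keeps working.
    return state
```

The trade-off is that, under `jax.jit`, an invalid state such as `[0, 2]` passes through without an error, because its values are not known while tracing.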
Hey @albi3ro, could you please confirm if moving forward with this approach is the right call or if there's a better direction to take? Thanks!
```python
class BasisStateProjector(qml.ops.Projector, qml.operation.Observable):
    def __init__(self, state, wires, id=None):
        wires = qml.wires.Wires(wires)
        state_array = qml.math.toarray(state)
        if not qml.math.is_abstract(state):
            state_array = state_array.astype(int)
            if not set(state_array).issubset({0, 1}):
                raise ValueError(f"Basis state must only consist of 0s and 1s; got {state_array}")
        super().__init__(state_array, wires=wires, id=id)

    @jax.jit
    def compute_matrix(self):
        return BasisStateProjector.compute_matrix(self.parameters[0])

    @staticmethod
    @jax.jit
    def compute_matrix(basis_state):
        m = np.zeros((2 ** len(basis_state), 2 ** len(basis_state)))
        idx = int("".join(str(i) for i in basis_state), 2)
        m[idx, idx] = 1
        return m

    @jax.jit
    def compute_eigvals(self):
        return BasisStateProjector.compute_eigvals(self.parameters[0])

    @staticmethod
    @jax.jit
    def compute_eigvals(basis_state):
        w = np.zeros(2 ** len(basis_state))
        idx = int("".join(str(i) for i in basis_state), 2)
        w[idx] = 1
        return w

    @jax.jit
    def compute_diagonalizing_gates(self, wires):
        return BasisStateProjector.compute_diagonalizing_gates(self.parameters[0], wires)

    @staticmethod
    @jax.jit
    def compute_diagonalizing_gates(basis_state, wires):
        return []
```
One restriction for PennyLane is that we try to make it compatible with all interfaces, but also compatible with any subset of installed ML interfaces. So if the user doesn't have JAX installed, or just isn't using it, everything should still work. We can't import jax unless we already know the user has it installed and wants to use it.
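One generic Python pattern for respecting this restriction, sketched here as an illustration rather than as PennyLane's actual mechanism, is to never import JAX yourself and instead look it up in `sys.modules`: if JAX has not been imported by anyone, no JAX array can exist. The helper name `is_jax_array` is hypothetical.

```python
import sys

import numpy as np


def is_jax_array(x):
    """Return True if ``x`` is a JAX array, without ever importing JAX."""
    jax = sys.modules.get("jax")
    if jax is None:
        # JAX was never imported, so x cannot be a JAX array.
        return False
    return isinstance(x, jax.Array)


print(is_jax_array(np.zeros(2)))  # a NumPy array is never a JAX array
```

This keeps JAX an optional dependency: the check costs a dictionary lookup when JAX is absent and never triggers an import.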
One way we do that is by dispatching through the `qml.math` module, which is basically a wrapper around autoray:
```python
>>> import jax
>>> x = jax.numpy.array(0.5)
>>> qml.math.get_interface(x)
'jax'
>>> qml.math.is_abstract(x)  # checking if jitting
False
```
It also provides various other NumPy-like functions.
For example, you can see us include TensorFlow-specific logic here:
https://github.com/PennyLaneAI/pennylane/blob/f53b70dd62205e288ceb7d2c17b56cb7965bdfa9/pennylane/ops/qubit/parametric_ops_multi_qubit.py#L150
Or bypass jit-incompatible logic:
https://github.com/PennyLaneAI/pennylane/blob/f53b70dd62205e288ceb7d2c17b56cb7965bdfa9/pennylane/ops/qubit/matrix_ops.py#L138
The current implementation is doing two main things that are incompatible with jitting:
- Converting to numpy and then a list. This looks to be fixed in your implementation. We can just leave the input as is.
https://github.com/PennyLaneAI/pennylane/blob/f53b70dd62205e288ceb7d2c17b56cb7965bdfa9/pennylane/ops/qubit/observables.py#L439
- Updating a matrix in place:
https://github.com/PennyLaneAI/pennylane/blob/f53b70dd62205e288ceb7d2c17b56cb7965bdfa9/pennylane/ops/qubit/observables.py#L499
As you can see in the "Sharp Bits" section of the JAX documentation, in-place updates are forbidden in JAX. Instead, they recommend the `jax_array.at[1, :].set(1.0)` syntax.
I would probably check whether the interface is JAX and, in that case, use their recommended approach to updating an existing array.
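A minimal sketch of that branch, assuming the interface string has already been determined elsewhere (for example with `qml.math.get_interface`); the helper `set_diag_entry` is hypothetical. Its JAX branch needs no `import jax`, because `.at` is a method on the array itself; the import below is only for the `jax.jit` usage.

```python
import jax
import jax.numpy as jnp
import numpy as np


def set_diag_entry(mat, idx, interface):
    """Return a copy of ``mat`` with ``mat[idx, idx] = 1`` (hypothetical helper)."""
    if interface == "jax":
        # JAX arrays are immutable: .at[...].set(...) returns a new array,
        # and it also works when idx is a traced value under jax.jit.
        return mat.at[idx, idx].set(1.0)
    # NumPy path: copy first so the caller's array is not mutated.
    out = np.array(mat, copy=True)
    out[idx, idx] = 1.0
    return out


# The interface string is hashable, so it can be marked static for jit.
jitted = jax.jit(set_diag_entry, static_argnames="interface")
result = jitted(jnp.zeros((4, 4)), 2, "jax")
```

Copying in the NumPy branch keeps both branches side-effect free, so callers see the same semantics regardless of interface.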
The line:
int("".join(str(i) for i in basis_state), 2)
might also not be JIT-friendly, since it converts to strings and back. It may need to be replaced with a purely arithmetic computation.
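One purely arithmetic replacement, sketched under the assumption that `basis_state` is a 1-D integer array in the same big-endian order as the string-based version (the first entry is the most significant bit), is a dot product with descending powers of two. Because the length comes from the array's shape, it stays static under `jax.jit`. The function names `basis_state_index`, `projector_eigvals`, and `projector_matrix` are hypothetical, not PennyLane API.

```python
import jax
import jax.numpy as jnp


def basis_state_index(basis_state):
    """Big-endian bits -> integer index, without any string conversion."""
    n = basis_state.shape[0]                 # shape is static under jit
    powers = 2 ** jnp.arange(n - 1, -1, -1)  # [2**(n-1), ..., 2, 1]
    return jnp.dot(basis_state, powers)


@jax.jit
def projector_eigvals(basis_state):
    n = basis_state.shape[0]
    idx = basis_state_index(basis_state)
    # Functional update instead of w[idx] = 1
    return jnp.zeros(2 ** n).at[idx].set(1.0)


@jax.jit
def projector_matrix(basis_state):
    # The projector is diagonal, with its eigenvalues on the diagonal.
    return jnp.diag(projector_eigvals(basis_state))
```

For example, `projector_eigvals(jnp.array([1, 0]))` places the single `1.0` at index 2, matching `int("10", 2)`.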
Hope that helps.