[BUG] Jitting `Projector` gives `ValueError`

Open albi3ro opened this issue 1 year ago • 6 comments

Expected behavior

I expect to be able to run the below code sample without error.

The answer should be the same as the non-jit alternative.

Actual behavior

I get the error included below.

Additional information

No response

Source code

import jax
import numpy as np
import pennylane as qml

@qml.qnode(qml.device('default.qubit'))
def circuit(state):
    return qml.expval(qml.Projector(state, wires=0))

jax.jit(circuit)(np.array([0]))

Tracebacks

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
File ~/Prog/pennylane/pennylane/math/single_dispatch.py:737, in _to_numpy_jax(x)
    736 try:
--> 737     return np.array(getattr(x, "val", x))
    738 except TracerArrayConversionError as e:

File ~/Prog/pl311/lib/python3.11/site-packages/jax/_src/core.py:710, in Tracer.__array__(self, *args, **kw)
    709 def __array__(self, *args, **kw):
--> 710   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[1].
The error occurred while tracing the function circuit at /var/folders/k1/0v_kvphn55lgf_45kntf1hqm0000gq/T/ipykernel_38839/71224146.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument state.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[4], line 7
      3 @qml.qnode(qml.device('default.qubit'))
      4 def circuit(state):
      5     return qml.expval(qml.Projector(state, wires=0))
----> 7 jax.jit(circuit)(np.array([0]))

    [... skipping hidden 12 frame]

File ~/Prog/pennylane/pennylane/qnode.py:976, in QNode.__call__(self, *args, **kwargs)
    973         kwargs["shots"] = _get_device_shots(self._original_device)
    975 # construct the tape
--> 976 self.construct(args, kwargs)
    978 cache = self.execute_kwargs.get("cache", False)
    979 using_custom_cache = (
    980     hasattr(cache, "__getitem__")
    981     and hasattr(cache, "__setitem__")
    982     and hasattr(cache, "__delitem__")
    983 )

File ~/Prog/pennylane/pennylane/qnode.py:862, in QNode.construct(self, args, kwargs)
    859     self.interface = qml.math.get_interface(*args, *list(kwargs.values()))
    861 with qml.queuing.AnnotatedQueue() as q:
--> 862     self._qfunc_output = self.func(*args, **kwargs)
    864 self._tape = QuantumScript.from_queue(q, shots)
    866 params = self.tape.get_parameters(trainable_only=False)

Cell In[4], line 5, in circuit(state)
      3 @qml.qnode(qml.device('default.qubit'))
      4 def circuit(state):
----> 5     return qml.expval(qml.Projector(state, wires=0))

File ~/Prog/pennylane/pennylane/ops/qubit/observables.py:439, in BasisStateProjector.__init__(self, state, wires, id)
    437 def __init__(self, state, wires, id=None):
    438     wires = Wires(wires)
--> 439     state = list(qml.math.toarray(state).astype(int))
    441     if not set(state).issubset({0, 1}):
    442         raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")

File ~/Prog/pl311/lib/python3.11/site-packages/autoray/autoray.py:80, in do(fn, like, *args, **kwargs)
     31 """Do function named ``fn`` on ``(*args, **kwargs)``, peforming single
     32 dispatch to retrieve ``fn`` based on whichever library defines the class of
     33 the ``args[0]``, or the ``like`` keyword argument if specified.
   (...)
     77     <tf.Tensor: id=91, shape=(3, 3), dtype=float32>
     78 """
     79 backend = choose_backend(fn, *args, like=like, **kwargs)
---> 80 return get_lib_fn(backend, fn)(*args, **kwargs)

File ~/Prog/pennylane/pennylane/math/single_dispatch.py:739, in _to_numpy_jax(x)
    737     return np.array(getattr(x, "val", x))
    738 except TracerArrayConversionError as e:
--> 739     raise ValueError(
    740         "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
    741     ) from e

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
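
The root cause in the traceback can be reproduced without PennyLane. The following is a minimal sketch, assuming only that jax and numpy are installed; `to_list` is a hypothetical function that mimics the list(qml.math.toarray(state).astype(int)) line, not PennyLane code.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_list(state):
    # np.array(...) needs concrete values, which a traced array does not have.
    return list(np.array(state).astype(int))


# Outside of jit the array is concrete, so the conversion succeeds.
to_list(jnp.array([0]))

caught = None
try:
    # Inside jit, `state` is a tracer and the conversion to NumPy fails.
    jax.jit(to_list)(jnp.array([0]))
except jax.errors.TracerArrayConversionError as err:
    caught = type(err).__name__

print(caught)
```

PennyLane catches this same exception in _to_numpy_jax and re-raises it as the ValueError shown above.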

System information

PL master

Existing GitHub issues

  • [X] I have searched existing GitHub issues to make sure the issue does not already exist.

albi3ro avatar Jan 02 '24 18:01 albi3ro

Could this be fixed in a similar way to #4966?

trbromley avatar Jan 03 '24 15:01 trbromley

This can be solved by bypassing the check that converts the inputted state to a list.

https://github.com/PennyLaneAI/pennylane/blob/3fd3c0acac6bfacbaa499d0e2bec429a66769e78/pennylane/ops/qubit/observables.py#L439

Honestly, not quite sure why we are always converting the basis_state to a list.

This also needs a couple of other fixes that overlap with #4971.

albi3ro avatar Jan 03 '24 18:01 albi3ro

Hey @albi3ro, is there anything I can help with on this?

AnuravModak avatar Jan 09 '24 05:01 AnuravModak

@AnuravModak, I've assigned you.

I think 🤞 that we should be able to fix this by:

  1. Removing the cast to list linked above. This may cause some tests to break, but I think we should be OK with breaking that behavior.
  2. Only performing the validation check if the state isn't abstract (qml.math.is_abstract).
  3. Potentially updating compute_eigvals to be jittable as well.
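
The first two steps above could look roughly like the sketch below. It is a simplified illustration, not the PennyLane implementation: `is_traced` is a hypothetical stand-in for qml.math.is_abstract that only checks for JAX tracers, and `check_basis_state` is an invented name.

```python
import jax
import jax.numpy as jnp


def is_traced(x):
    """Hypothetical stand-in for qml.math.is_abstract: True for JAX tracers."""
    return isinstance(x, jax.core.Tracer)


def check_basis_state(state):
    """Validate a basis state only when its values are concrete."""
    if not is_traced(state):
        bits = {int(b) for b in state}
        if not bits.issubset({0, 1}):
            raise ValueError(f"Basis state must only consist of 0s and 1s; got {state}")
    # The input is returned as-is, with no cast to a Python list.
    return state


check_basis_state(jnp.array([0, 1]))           # concrete input: validated
jax.jit(check_basis_state)(jnp.array([0, 1]))  # traced input: validation skipped
```

One consequence of this design is that an invalid state such as [0, 2] is only rejected outside of jit. Under jit the values are unknown at trace time, so the check cannot run.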

Feel free to leave any additional things that crop up either here or on a PR.

albi3ro avatar Jan 09 '24 16:01 albi3ro

Hey @albi3ro, could you please confirm if moving forward with this approach is the right call or if there's a better direction to take? Thanks!

class BasisStateProjector(qml.ops.Projector, qml.operation.Observable):
 

    def __init__(self, state, wires, id=None):
        wires = qml.wires.Wires(wires)
        state_array = qml.math.toarray(state)

        if not qml.math.is_abstract(state):
            state_array = state_array.astype(int)
            if not set(state_array).issubset({0, 1}):
                raise ValueError(f"Basis state must only consist of 0s and 1s; got {state_array}")

        super().__init__(state_array, wires=wires, id=id)

    @jax.jit
    def compute_matrix(self):
        return BasisStateProjector.compute_matrix(self.parameters[0])

    @staticmethod
    @jax.jit
    def compute_matrix(basis_state):
        m = np.zeros((2 ** len(basis_state), 2 ** len(basis_state)))
        idx = int("".join(str(i) for i in basis_state), 2)
        m[idx, idx] = 1
        return m

    @jax.jit
    def compute_eigvals(self):
        return BasisStateProjector.compute_eigvals(self.parameters[0])

    @staticmethod
    @jax.jit
    def compute_eigvals(basis_state):
        w = np.zeros(2 ** len(basis_state))
        idx = int("".join(str(i) for i in basis_state), 2)
        w[idx] = 1
        return w

    @jax.jit
    def compute_diagonalizing_gates(self, wires):
        return BasisStateProjector.compute_diagonalizing_gates(self.parameters[0], wires)

    @staticmethod
    @jax.jit
    def compute_diagonalizing_gates(basis_state, wires):
        return []

AnuravModak avatar Jan 21 '24 19:01 AnuravModak

One restriction for PennyLane is that we try to make it compatible with all interfaces, but also compatible with any subset of installed ML interfaces. So if the user doesn't have jax, or just isn't using it, everything should work fine. We can't import jax unless we already know the user has jax installed and wants to use it.

One way we do that is by dispatching through the qml.math module, which is basically a wrapper around autoray:

>>> import jax
>>> import pennylane as qml
>>> x = jax.numpy.array(0.5)
>>> qml.math.get_interface(x)
'jax'
>>> qml.math.is_abstract(x) # checking if jitting
False

qml.math also provides various other numpy-like functions.

For example, you can see us include tensorflow-specific logic like:

https://github.com/PennyLaneAI/pennylane/blob/f53b70dd62205e288ceb7d2c17b56cb7965bdfa9/pennylane/ops/qubit/parametric_ops_multi_qubit.py#L150

Or bypass jit-incompatible logic:

https://github.com/PennyLaneAI/pennylane/blob/f53b70dd62205e288ceb7d2c17b56cb7965bdfa9/pennylane/ops/qubit/matrix_ops.py#L138

The current implementation is doing two main things that are incompatible with jitting:

  • Converting to numpy and then a list. This looks to be fixed in your implementation. We can just leave the input as is.

https://github.com/PennyLaneAI/pennylane/blob/f53b70dd62205e288ceb7d2c17b56cb7965bdfa9/pennylane/ops/qubit/observables.py#L439

  • Updating a matrix in place:

https://github.com/PennyLaneAI/pennylane/blob/f53b70dd62205e288ceb7d2c17b56cb7965bdfa9/pennylane/ops/qubit/observables.py#L499

As you can see in the "sharp bits" section of the jax documentation, in-place updates are forbidden with jax. Instead, they recommend the jax_array.at[1, :].set(1.0) syntax.

I would probably check if the interface is jax, and use their recommended approach to updating an existing array in that case.
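
A minimal sketch of that interface check is shown below. Both helpers are hypothetical and are not PennyLane or autoray code: `interface_of` guesses the library from the type's module name so that jax is never imported by the helper itself, and `with_diagonal_one` is an invented name for the in-place update being replaced.

```python
import numpy as np


def interface_of(x):
    """Hypothetical helper: guess the array library from the type's module name."""
    root = type(x).__module__.split(".")[0]
    # Concrete jax arrays live under "jaxlib", tracers under "jax".
    return "jax" if root.startswith("jax") else root


def with_diagonal_one(matrix, idx):
    """Return a copy of `matrix` whose (idx, idx) entry is set to 1."""
    if interface_of(matrix) == "jax":
        # Functional update: returns a new array and works under jax.jit.
        return matrix.at[idx, idx].set(1.0)
    out = np.array(matrix, dtype=float)  # copy first, then update in place
    out[idx, idx] = 1.0
    return out


projector = with_diagonal_one(np.zeros((4, 4)), 2)
print(projector)
```

Passing a jax array (or a tracer inside jax.jit) takes the first branch, so the jax-specific syntax is only reached when the caller is already using jax.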

The line:

int("".join(str(i) for i in basis_state), 2)

might also not be jit-friendly, as it converts to strings and back. This may need to be converted to a purely mathematical function.
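
One purely mathematical alternative, sketched here under the assumption that the basis state is a one-dimensional array of 0s and 1s with the most significant bit first, is a dot product with powers of two. The function names `basis_state_index` and `projector_eigvals` are hypothetical and chosen for illustration only.

```python
import jax
import jax.numpy as jnp


def basis_state_index(basis_state):
    """Integer index of a big-endian bit array, matching int("101", 2) for [1, 0, 1]."""
    n = basis_state.shape[0]  # the shape is static, even under jit
    weights = 2 ** jnp.arange(n - 1, -1, -1)  # [2**(n-1), ..., 2, 1]
    return jnp.dot(basis_state, weights)


def projector_eigvals(basis_state):
    """Eigenvalues of the basis-state projector: a one-hot vector at the state's index."""
    n = basis_state.shape[0]
    idx = basis_state_index(basis_state)
    return jnp.zeros(2 ** n).at[idx].set(1.0)


eigvals = jax.jit(projector_eigvals)(jnp.array([1, 0, 1]))
print(eigvals)
```

Because only the array's shape is needed at trace time, both functions can be jitted with the basis state passed as a traced argument.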

Hope that helps.

albi3ro avatar Jan 22 '24 14:01 albi3ro