pennylane
[BUG] `ConcretizationTypeError` when unnecessary `work_wires` is specified.
Expected behavior
Hi,
While testing another issue that required `work_wires`, I found that some quantum programs cannot be compiled with `jax.jit` when `work_wires` is passed, even if no work wires are actually needed. In the submitted example, removing the unneeded keyword argument `work_wires=[7]` lets the circuit compile with `jax.jit` successfully.
Actual behavior
ConcretizationTypeError is raised.
Additional information
No response
Source code
import pennylane as qml
import jax

@jax.jit
@qml.qnode(qml.device("lightning.qubit", wires=8))
def circuit(x: int):
    op = qml.Identity(wires=[0])
    op2 = qml.ctrl(op, control=[x], work_wires=[7])
    qml.matrix(op2)
    return qml.state()

print(circuit(1))
Tracebacks
Traceback (most recent call last):
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 266, in fn
    return self.tape_fn(obj.expand(), *args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 301, in tape_fn
    return self._tape_fn(obj, *args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/functions/matrix.py", line 134, in _matrix
    raise qml.operation.MatrixUndefinedError
pennylane.operation.MatrixUndefinedError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/erick.ochoalopez/Code/catalyst-latest/test.py", line 12, in <module>
    print(circuit(1))
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/qnode.py", line 975, in __call__
    self.construct(args, kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/qnode.py", line 872, in construct
    self._tape = make_qscript(self.func, shots)(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/tape/qscript.py", line 1492, in wrapper
    result = fn(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/test.py", line 9, in circuit
    qml.matrix(op2)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 213, in __call__
    return self._create_wrapper(obj, *targs, **tkwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 412, in _create_wrapper
    wrapper = self.fn(obj, *targs, **tkwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 275, in fn
    raise e1 from e
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 260, in fn
    return self._fn(obj, *args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/functions/matrix.py", line 127, in matrix
    return op.matrix(wire_order=wire_order)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/op_math/controlled.py", line 433, in matrix
    return qml.math.expand_matrix(
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/math/matrix_manipulation.py", line 134, in expand_matrix
    wire_indices = [wire_order.index(wire) for wire in wires]
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/math/matrix_manipulation.py", line 134, in <listcomp>
    wire_indices = [wire_order.index(wire) for wire in wires]
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/core.py", line 667, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/core.py", line 1370, in error
    raise ConcretizationTypeError(arg, fname_context)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[].
The problem arose with the `bool` function.
The error occurred while tracing the function circuit at /home/erick.ochoalopez/Code/catalyst-latest/test.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/erick.ochoalopez/Code/catalyst-latest/test.py", line 12, in <module>
    print(circuit(1))
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/qnode.py", line 975, in __call__
    self.construct(args, kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/qnode.py", line 872, in construct
    self._tape = make_qscript(self.func, shots)(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/tape/qscript.py", line 1492, in wrapper
    result = fn(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/test.py", line 9, in circuit
    qml.matrix(op2)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 213, in __call__
    return self._create_wrapper(obj, *targs, **tkwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 412, in _create_wrapper
    wrapper = self.fn(obj, *targs, **tkwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 275, in fn
    raise e1 from e
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 260, in fn
    return self._fn(obj, *args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/functions/matrix.py", line 127, in matrix
    return op.matrix(wire_order=wire_order)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/op_math/controlled.py", line 433, in matrix
    return qml.math.expand_matrix(
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/math/matrix_manipulation.py", line 134, in expand_matrix
    wire_indices = [wire_order.index(wire) for wire in wires]
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/math/matrix_manipulation.py", line 134, in <listcomp>
    wire_indices = [wire_order.index(wire) for wire in wires]
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[].
The problem arose with the `bool` function.
The error occurred while tracing the function circuit at /home/erick.ochoalopez/Code/catalyst-latest/test.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
System information
>>> qml.about()
Name: PennyLane
Version: 0.32.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages
Editable project location: /home/erick.ochoalopez/Code/pennylane
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: pennylane-catalyst, PennyLane-Lightning
Existing GitHub issues
- [X] I have searched existing GitHub issues to make sure the issue does not already exist.
@erick-xanadu Shouldn't we be treating wire labels as static metadata / compile-time constants?
@albi3ro what would the motivation be for this? What sets your intuition on dynamic variables vs static constants?
Previously we've thought of "things that are potentially trainable", like any TensorLike, as the dynamic variables. That assumption is baked into both how we write things and how we test things.
When we allow a variable to be abstract, we strongly limit what we can do with it: we can no longer use it in Python control flow.
I also don't think we have a single test for abstract wires.
I would be open to allowing wires to be dynamic, but we would need time to adjust the assumptions in our code, add tests, and work through all the problems that will inevitably come up.
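To illustrate the point about abstract values and control flow, here is a minimal JAX-only sketch with no PennyLane involved. `WIRE_ORDER`, `wire_position`, `dynamic_lookup`, and `static_lookup` are hypothetical names; the lookup mimics the `wire_order.index(wire)` pattern from `expand_matrix` in the traceback. `list.index` has to evaluate `bool(entry == wire)` for each entry, which is impossible for a traced wire, whereas marking the argument static with `static_argnums` makes it an ordinary Python value during tracing (the "static metadata" option raised above).

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical wire order, standing in for the `wire_order` list in expand_matrix.
WIRE_ORDER = [0, 7]


def wire_position(wire):
    # list.index evaluates bool(entry == wire) for each entry in turn.
    return WIRE_ORDER.index(wire)


@jax.jit
def dynamic_lookup(wire):
    # `wire` is a tracer here, so bool(0 == wire) cannot be evaluated.
    return jnp.asarray(wire_position(wire))


@functools.partial(jax.jit, static_argnums=0)
def static_lookup(wire):
    # `wire` is a plain Python int here, so the lookup runs at trace time.
    return jnp.asarray(wire_position(wire))


print(static_lookup(7))  # position of wire 7 in WIRE_ORDER

try:
    dynamic_lookup(7)
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions may report a subclass such as TracerBoolConversionError.
    print(type(err).__name__)
```

The trade-off is that `jax.jit` caches compilations keyed on static argument values, so every distinct static wire label triggers a new trace; that is part of why choosing between static and dynamic wires is a real design decision.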
Hey all, just revisiting this issue now (with newer context we might have from plxpr work).
Will plxpr be treating wires as dynamic or static?
This does work with qjit btw :)
The posted example no longer raises an error. Can we close this now?
It shouldn't work though... That makes me kinda worried.
OK, it started working because `qml.matrix` started de-queueing the controlled op.
If the `qml.matrix` call is removed, it now gives an informative error:
WireError: Cannot run circuit(s) on lightning.qubit as abstract wires are present in the tape: Wires([Traced<~int32[]>with<DynamicJaxprTrace>, 0]). Abstract wires are not yet supported.