# PennyLane device "lightning.qubit" with 2 wires; the circuit below only acts on wire 0.
dev = qml.device("lightning.qubit", wires=2)
def my_quantum_transform(
    tape: qml.tape.QuantumTape,
) -> tuple[Sequence[qml.tape.QuantumTape], Callable]:
    """Duplicate *tape* into two tapes and sum their execution results.

    Args:
        tape: The input quantum tape.

    Returns:
        A pair ``(tapes, post_processing_fn)``. ``tapes`` holds the original
        ``tape`` followed by a second tape rebuilt from its operations,
        measurements and shots. ``post_processing_fn`` takes the two
        execution results and returns their sum.
    """
    tape1 = tape
    # Rebuild a second tape from the same operations and measurements.
    # Forward ``shots`` explicitly: otherwise the new tape uses the
    # constructor default (``shots=None``), and the two results summed in
    # ``post_processing_fn`` could come from different shot settings.
    # NOTE(review): the operation objects themselves are not copied, so both
    # tapes share them (including the catalyst.for_loop op). The accompanying
    # traceback shows that loop lowered with two qreg inputs while only one
    # qreg argument is passed; sharing the op across tapes looks like the
    # trigger -- TODO confirm.
    tape2 = qml.tape.QuantumTape(tape.operations, tape.measurements, shots=tape.shots)

    def post_processing_fn(results):
        """Return the sum of the results of ``tape1`` and ``tape2``."""
        return results[0] + results[1]

    return [tape1, tape2], post_processing_fn
# Wrap the tape-level function with qml.transform so it can be applied to a QNode.
dispatched_transform = qml.transform(my_quantum_transform)
@qml.qnode(dev)
def circuit():
    """Run RX(3.14) on wire 0 inside a Catalyst for-loop over [0, 1) and measure <X> on wire 0."""

    @catalyst.for_loop(0, 1, 1)
    def body(_step, acc):
        # Each iteration applies one RX gate and bumps the carried value by 2.
        qml.RX(3.14, wires=0)
        return acc + 2

    # Start the loop with carry 0; the final carried value is not used below.
    body(0)
    return qml.expval(qml.X(0))
# Apply the custom transform first, then JIT-compile the transformed QNode with Catalyst.
circuit = qjit(dispatched_transform(circuit))
print("qjit results: ", circuit())
>>>
Traceback (most recent call last):
File "/home/paul.wang/catalyst_new/catalyst/multi_tape.py", line 144, in <module>
circuit = qjit(circuit)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 376, in qjit
return QJIT(fn, CompileOptions(**kwargs))
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
output = func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 443, in __init__
self.aot_compile()
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 481, in aot_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
return fn(*args, **kwargs)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 606, in capture
jaxpr, out_type, treedef = trace_to_jaxpr(
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 531, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 604, in fn_with_transform_named_sequence
return self.user_function(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 584, in closure
return QFunc.__call__(qnode, *args, **dict(params, **kwargs))
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/qfunc.py", line 165, in __call__
res_flat = func_p.bind(flattened_fun, *args_flat, fn=self)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/qfunc.py", line 143, in _eval_quantum
closed_jaxpr, out_type, out_tree, out_tree_exp = trace_quantum_function(
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 1162, in trace_quantum_function
qrp_out = trace_quantum_operations(tape, device, qreg_in, ctx, trace, mcm_config)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 655, in trace_quantum_operations
qrp2 = op.trace_quantum(ctx, device, trace, qrp, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/api_extensions/control_flow.py", line 1209, in trace_quantum
op.bind_overwrite_classical_tracers(
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 460, in bind_overwrite_classical_tracers
out_quantum_tracer = self.binder(*in_expanded_tracers, **kwargs)[-1]
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/tracing.py", line 969, in bind
source_info = jax_current()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError: ({ lambda ; a:i64[] b:i64[] c:AbstractQreg() d:AbstractQreg(). let
e:i64[] = add b 2
f:AbstractQbit() = qextract c 0
g:AbstractQbit() = qinst[
adjoint=False
ctrl_len=0
op=RX
params_len=1
qubits_len=1
] f 3.14
_:AbstractQreg() = qinsert c 0 g
h:AbstractQbit() = qextract d 0
i:AbstractQbit() = qinst[
adjoint=False
ctrl_len=0
op=RX
params_len=1
qubits_len=1
] h 3.14
j:AbstractQreg() = qinsert d 0 i
in (e, j) }, ([<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x722b91bb37f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb0b70>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb25b0>]))
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/paul.wang/catalyst_new/catalyst/multi_tape.py", line 144, in <module>
circuit = qjit(circuit)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 376, in qjit
return QJIT(fn, CompileOptions(**kwargs))
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
output = func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 443, in __init__
self.aot_compile()
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 486, in aot_compile
self.mlir_module, self.mlir = self.generate_ir()
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
return fn(*args, **kwargs)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 621, in generate_ir
mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_tracer.py", line 558, in lower_jaxpr_to_mlir
mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
module, context = custom_lower_jaxpr_to_module(
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 608, in _func_lowering
func_op = _func_def_lowering(ctx.module_context, fn, call_jaxpr, name_stack=ctx.name_stack)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 566, in _func_def_lowering
func_op = mlir.lower_jaxpr_to_fun(ctx, fn.__name__, call_jaxpr, tuple(), name_stack=name_stack)
File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jax_primitives.py", line 2022, in _for_loop_lowering
out, _ = mlir.jaxpr_subcomp(
File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1574, in jaxpr_subcomp
assert len(args) == len(jaxpr.invars), (jaxpr, args)
AssertionError: ({ lambda ; a:i64[] b:i64[] c:AbstractQreg() d:AbstractQreg(). let
e:i64[] = add b 2
f:AbstractQbit() = qextract c 0
g:AbstractQbit() = qinst[
adjoint=False
ctrl_len=0
op=RX
params_len=1
qubits_len=1
] f 3.14
_:AbstractQreg() = qinsert c 0 g
h:AbstractQbit() = qextract d 0
i:AbstractQbit() = qinst[
adjoint=False
ctrl_len=0
op=RX
params_len=1
qubits_len=1
] h 3.14
j:AbstractQreg() = qinsert d 0 i
in (e, j) }, ([<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x722b91bb37f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb0b70>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x722b91bb25b0>]))