catalyst
grad does not work when using dynamic one-shot
Issue description
Using grad does not work with dynamic one-shot (mcm_method="one-shot").
- Actual behavior: The following code crashes:
@qml.qnode(dev, diff_method="best", mcm_method="one-shot")
def f(x: float):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(wires=0))

@qjit
def grad_f(x):
    return grad(f, method="auto")(x)

print(grad_f(1.0))
which crashes with the following message:
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/test-issues.py", line 40, in <module>
print(grad_f(1.0))
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 457, in __call__
requires_promotion = self.jit_compile(args, **kwargs)
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 528, in jit_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
return fn(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 610, in capture
jaxpr, out_type, treedef = trace_to_jaxpr(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 537, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 608, in fn_with_transform_named_sequence
return self.user_function(*args, **kwargs)
File "/Users/mehrdad.malek/tmp/test-issues.py", line 38, in grad_f
return grad(f, method="auto")(x)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 688, in __call__
results = grad_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/test-issues.py", line 40, in <module>
print(grad_f(1.0))
^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 457, in __call__
requires_promotion = self.jit_compile(args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 532, in jit_compile
self.mlir_module, self.mlir = self.generate_ir()
^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 625, in generate_ir
mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 564, in lower_jaxpr_to_mlir
mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
module, context = custom_lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 711, in _grad_lowering
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
- Expected behavior: The same circuit without mcm_method="one-shot"
dev = qml.device('lightning.qubit', wires=1, shots=5)

@qml.qnode(dev, diff_method="best")
def g(x: float):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(wires=0))

@qjit
def grad_g(x):
    return grad(g, method="auto")(x)
returns
-0.4
- Reproduces how often: 100%
- System information:
Name: PennyLane
Version: 0.38.0.dev24
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
Required-by: amazon-braket-pennylane-plugin, PennyLane-Catalyst, pennylane-qrack, PennyLane_Lightning, PennyLane_Lightning_Kokkos
Platform info: macOS-14.6.1-arm64-arm-64bit
Python version: 3.12.4
Numpy version: 1.26.4
Scipy version: 1.12.0
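For context on the expected value: the working circuit reports -0.4 at x = 1.0 with shots=5, whereas the analytic gradient of the RX/PauliZ expectation value, cos(x), is -sin(1.0), roughly -0.84. The sketch below is a plain-Python illustration of how a two-term parameter-shift estimate built from 5-shot samples lands on multiples of 0.2. It assumes a parameter-shift style estimator; the issue does not state which gradient method Catalyst selects here, and the helper names are hypothetical.

```python
import math
import random

def expval_z_exact(x):
    """Exact <Z> after RX(x) applied to |0>, which equals cos(x)."""
    return math.cos(x)

def expval_z_sampled(x, shots, rng):
    """Estimate <Z> from `shots` samples, with P(+1) = (1 + cos(x)) / 2."""
    p_plus = (1.0 + math.cos(x)) / 2.0
    total = sum(1 if rng.random() < p_plus else -1 for _ in range(shots))
    return total / shots

def parameter_shift(f, x):
    """Two-term parameter-shift rule for a single RX angle."""
    shift = math.pi / 2
    return (f(x + shift) - f(x - shift)) / 2.0

# Analytic case: the rule reproduces -sin(x) exactly.
exact_grad = parameter_shift(expval_z_exact, 1.0)

# Finite-shot case: each 5-shot estimate is an odd integer divided by 5,
# so the gradient estimate is always a multiple of 0.2 (for example -0.4).
rng = random.Random(0)  # seed chosen only to make the sketch repeatable
noisy_grad = parameter_shift(lambda v: expval_z_sampled(v, 5, rng), 1.0)

print(exact_grad, noisy_grad)
```

A result such as -0.4 is therefore consistent with a noisy 5-shot estimate rather than an exact derivative.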
The same crash occurs with value_and_grad:
def workflow(x: float):
    @qml.qnode(qml.device("lightning.qubit", wires=3, shots=10), mcm_method="one-shot")
    def circuit1():
        qml.CNOT(wires=[0, 1])
        qml.RX(0, wires=[2])
        return qml.probs()  # This is [1, 0, 0, ...]

    return x * (circuit1()[0])

result2 = qjit(value_and_grad(workflow))(3.0)
The same crash occurs with jvp and vjp:
x, t = (
    [-0.1, 0.5],
    [0.1, 0.33],
)

def circuit_rx(x1, x2):
    """A test quantum function"""
    qml.RX(x1, wires=0)
    qml.RX(x2, wires=0)
    return qml.expval(qml.PauliY(0))

@qjit
def C_workflow():
    f = qml.QNode(circuit_rx, device=dev, mcm_method="one-shot")
    return C_jvp(f, x, t, method="auto", argnums=list(range(len(x))))

r1 = C_workflow()
The traceback is slightly different, but the underlying problem appears to be the same:
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/test-issues.py", line 34, in <module>
@qjit
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 377, in qjit
return QJIT(fn, CompileOptions(**kwargs))
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
output = func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 445, in __init__
self.aot_compile()
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 483, in aot_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
return fn(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 610, in capture
jaxpr, out_type, treedef = trace_to_jaxpr(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 537, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 608, in fn_with_transform_named_sequence
return self.user_function(*args, **kwargs)
File "/Users/mehrdad.malek/tmp/test-issues.py", line 37, in C_workflow
return C_jvp(f, x, t, method="auto", argnums=list(range(len(x))))
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 488, in jvp
results = jvp_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: unimplemented array format conversion from format: ?
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/test-issues.py", line 34, in <module>
@qjit
^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 377, in qjit
return QJIT(fn, CompileOptions(**kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 445, in __init__
self.aot_compile()
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 488, in aot_compile
self.mlir_module, self.mlir = self.generate_ir()
^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 625, in generate_ir
mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 564, in lower_jaxpr_to_mlir
mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
module, context = custom_lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 821, in _jvp_lowering
StableHLOConstantOp(ir.DenseElementsAttr.get(np.asarray(const))).results
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: unimplemented array format conversion from format: ?
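A note on reading the last line of this traceback: the trailing `?` is not a placeholder. In Python's struct and buffer-protocol notation, `?` is the format code for a boolean element, and NumPy boolean arrays expose that same format code. This suggests, though the issue does not confirm it, that a boolean constant reached `ir.DenseElementsAttr.get` in `_jvp_lowering`. The standard-library sketch below only demonstrates the format code itself; it does not touch Catalyst or MLIR.

```python
import struct

# '?' is the struct / buffer-protocol format code for a C _Bool (one byte).
bool_size = struct.calcsize("?")

# Reinterpret three raw bytes as booleans through the buffer protocol.
flags = memoryview(b"\x01\x00\x01").cast("?")
buffer_format = flags.format  # the same code that appears in the error message
values = flags.tolist()

print(bool_size, buffer_format, values)
```

If that reading is right, checking the dtype of the constants passed to the lowering rule would be a reasonable first debugging step.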
The following tests have already been added to Catalyst in catalyst/frontend/test/pytest/test_mid_circuit_measurement.py and are marked xfail or skip. Any solution to this issue should make them pass:
test_mcm_method_with_grad
test_mcm_method_with_value_and_grad
test_mcm_method_with_jvp
test_mcm_method_with_jvp
Closing this issue for now, as it is not reproducible with PennyLane (v0.43.0.dev69) and Catalyst (v0.13.0.dev69). Gradient-related operations are currently unsupported due to known issues with Enzyme. I have created two new issues (#2087, #2088) to track related bugs in dynamic one-shot as part of this investigation.
Gradient-related operations are currently unsupported due to known issues with Enzyme
@lazypanda10117 what are you referring to here? E.g., is this a new, wider gradient-related bug, or an existing known bug?
Instead of the error Mehrdad encountered, the main issue we ran into was the longstanding Enzyme issue with grad combined with vmap/for loops.