grad does not work when using dynamic one-shot

Open: mehrdad2m opened this issue on Sep 03 '24

Issue description

Differentiating a QNode with grad inside qjit fails when the QNode uses dynamic one-shot (mcm_method="one-shot").

  • Actual behavior: compiling the following code crashes:
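import pennylane as qml
from catalyst import grad, qjit

# Same device as in the expected-behavior snippet below
dev = qml.device("lightning.qubit", wires=1, shots=5)

@qml.qnode(dev, diff_method="best", mcm_method="one-shot")
def f(x: float):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(wires=0))

@qjit
def grad_f(x):
    return grad(f, method="auto")(x)

print(grad_f(1.0))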

which crashes with the following message:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 40, in <module>
    print(grad_f(1.0))
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 457, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 528, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 610, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 537, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 608, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 38, in grad_f
    return grad(f, method="auto")(x)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 688, in __call__
    results = grad_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 40, in <module>
    print(grad_f(1.0))
          ^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 457, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 532, in jit_compile
    self.mlir_module, self.mlir = self.generate_ir()
                                  ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 625, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 564, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 711, in _grad_lowering
    attr = ir.DenseElementsAttr.get(nparray, type=const_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.
  • Expected behavior: the same circuit without mcm_method="one-shot" compiles and returns a gradient:

    import pennylane as qml
    from catalyst import grad, qjit

    dev = qml.device('lightning.qubit', wires=1, shots=5)

    @qml.qnode(dev, diff_method="best")
    def g(x: float):
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(wires=0))

    @qjit
    def grad_g(x):
        return grad(g, method="auto")(x)

    print(grad_g(1.0))

    which prints

     -0.4
    
  • Reproduces how often: 100%

  • System information:

    Name: PennyLane
    Version: 0.38.0.dev24
    Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
    Home-page: https://github.com/PennyLaneAI/pennylane
    Author: 
    Author-email: 
    License: Apache License 2.0
    Location: /Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages
    Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, toml, typing-extensions
    Required-by: amazon-braket-pennylane-plugin, PennyLane-Catalyst, pennylane-qrack, PennyLane_Lightning, PennyLane_Lightning_Kokkos
    
    Platform info:           macOS-14.6.1-arm64-arm-64bit
    Python version:          3.12.4
    Numpy version:           1.26.4
    Scipy version:           1.12.0
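For reference, the -0.4 printed in the expected-behavior case is a finite-shot estimate. For RX(x) applied to |0>, <Z> = cos(x), so the exact derivative at x = 1.0 is -sin(1.0) ≈ -0.841. The minimal pure-Python sketch below assumes (the report does not say this) that the shot-based gradient is computed with the two-term parameter-shift rule; the helper names are hypothetical and no PennyLane or Catalyst API is used.

```python
import math
import random

def expval_z(x: float) -> float:
    """Exact <Z> after RX(x)|0>: cos^2(x/2) - sin^2(x/2) = cos(x)."""
    return math.cos(x)

def sampled_expval_z(x: float, shots: int, rng: random.Random) -> float:
    """Finite-shot <Z>: mean of +1/-1 outcomes, where +1 occurs with probability cos^2(x/2)."""
    p_plus = math.cos(x / 2) ** 2
    outcomes = [1 if rng.random() < p_plus else -1 for _ in range(shots)]
    return sum(outcomes) / shots

def parameter_shift(f, x: float) -> float:
    """Two-term parameter-shift rule for a Pauli rotation such as RX."""
    shift = math.pi / 2
    return (f(x + shift) - f(x - shift)) / 2

exact = parameter_shift(expval_z, 1.0)  # equals -sin(1.0) up to rounding
rng = random.Random(0)  # fixed seed so the sketch is repeatable
noisy = parameter_shift(lambda y: sampled_expval_z(y, 5, rng), 1.0)
print(f"exact={exact:.4f}  5-shot estimate={noisy:.1f}")
```

With 5 shots each expectation value lies on a grid with spacing 0.4, so the shifted difference divided by 2 is always a multiple of 0.2; -0.4 is one such value.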
    

mehrdad2m, Sep 03 '24 13:09

The same crash occurs with value_and_grad:

import pennylane as qml
from catalyst import qjit, value_and_grad

def workflow(x: float):
    @qml.qnode(qml.device("lightning.qubit", wires=3, shots=10), mcm_method="one-shot")
    def circuit1():
        qml.CNOT(wires=[0, 1])
        qml.RX(0, wires=[2])
        return qml.probs()  # This is [1, 0, 0, ...]

    return x * (circuit1()[0])

result2 = qjit(value_and_grad(workflow))(3.0)
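Without the crash, this workflow should return a value of 3.0 and a derivative of 1.0: CNOT acting on |000> and RX(0) both leave the state unchanged, so circuit1()[0] is 1 and workflow(x) = x. Because |000> occurs with certainty, the 10-shot estimate is also exactly 1. The sketch below checks this with a hypothetical pure-Python state-vector stand-in for circuit1 and a central finite difference in place of Catalyst's value_and_grad; it is an illustration, not how Catalyst evaluates the circuit.

```python
import math

def circuit1_probs(theta: float = 0.0) -> list[float]:
    """Hypothetical stand-in for circuit1: exact 3-qubit state-vector simulation.

    Basis index = 4*q0 + 2*q1 + q2, i.e. wire 0 is the most significant bit.
    """
    state = [0j] * 8
    state[0] = 1 + 0j  # start in |000>
    # CNOT(wires=[0, 1]): swap the |1,0,q2> and |1,1,q2> amplitudes
    for q2 in (0, 1):
        a, b = 0b100 | q2, 0b110 | q2
        state[a], state[b] = state[b], state[a]
    # RX(theta) on wire 2: [[cos, -i sin], [-i sin, cos]] with half-angle theta/2
    c, s = math.cos(theta / 2), -1j * math.sin(theta / 2)
    for k in range(0, 8, 2):
        lo, hi = state[k], state[k + 1]
        state[k], state[k + 1] = c * lo + s * hi, s * lo + c * hi
    return [abs(amp) ** 2 for amp in state]

def workflow(x: float) -> float:
    return x * circuit1_probs()[0]

def value_and_grad_fd(f, x: float, h: float = 1e-6):
    """Return (f(x), df/dx) using a central finite difference."""
    return f(x), (f(x + h) - f(x - h)) / (2 * h)

value, derivative = value_and_grad_fd(workflow, 3.0)
print(value, derivative)
```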

mehrdad2m, Sep 03 '24 13:09

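The same crash occurs with jvp and vjp. For example, with jvp:

import pennylane as qml
from catalyst import jvp as C_jvp, qjit

# Assumed device: the original snippet does not show how dev was defined
dev = qml.device("lightning.qubit", wires=1, shots=5)

x, t = (
    [-0.1, 0.5],
    [0.1, 0.33],
)

def circuit_rx(x1, x2):
    """A test quantum function"""
    qml.RX(x1, wires=0)
    qml.RX(x2, wires=0)
    return qml.expval(qml.PauliY(0))

@qjit
def C_workflow():
    f = qml.QNode(circuit_rx, device=dev, mcm_method="one-shot")
    return C_jvp(f, x, t, method="auto", argnums=list(range(len(x))))


r1 = C_workflow()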
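The report does not state the values this jvp call should produce. As a reference in exact (infinite-shot) arithmetic: two RX rotations on the same wire compose to RX(x1 + x2), so <Y> = -sin(x1 + x2), and the JVP along t is -cos(x1 + x2) * (t1 + t2), roughly -0.389 and -0.396 at the given x and t. The pure-Python sketch below cross-checks the analytic JVP against a central finite difference; the helper names are hypothetical, and a finite-shot device would only approximate these numbers.

```python
import math

def circuit_rx_expval(x1: float, x2: float) -> float:
    """Exact <Y> for RX(x2) RX(x1)|0>; the rotations compose to RX(x1 + x2)."""
    return -math.sin(x1 + x2)

def jvp_fd(f, primals, tangents, h: float = 1e-6):
    """Return (f(primals), directional derivative along tangents) by central difference."""
    plus = [p + h * v for p, v in zip(primals, tangents)]
    minus = [p - h * v for p, v in zip(primals, tangents)]
    return f(*primals), (f(*plus) - f(*minus)) / (2 * h)

x, t = [-0.1, 0.5], [0.1, 0.33]
value, jvp_numeric = jvp_fd(circuit_rx_expval, x, t)
jvp_analytic = -math.cos(x[0] + x[1]) * (t[0] + t[1])
print(f"value={value:.4f} jvp_numeric={jvp_numeric:.4f} jvp_analytic={jvp_analytic:.4f}")
```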

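The traceback is slightly different, but the underlying problem appears to be the same: in both cases the failure occurs when ir.DenseElementsAttr.get lowers a constant captured by the differentiation primitive (_grad_lowering in the first traceback, _jvp_lowering here):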

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 34, in <module>
    @qjit
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 377, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 445, in __init__
    self.aot_compile()
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 483, in aot_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 610, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 537, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 555, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 608, in fn_with_transform_named_sequence
    return self.user_function(*args, **kwargs)
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 37, in C_workflow
    return C_jvp(f, x, t, method="auto", argnums=list(range(len(x))))
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 488, in jvp
    results = jvp_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: unimplemented array format conversion from format: ?

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/test-issues.py", line 34, in <module>
    @qjit
     ^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 377, in qjit
    return QJIT(fn, CompileOptions(**kwargs))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 65, in wrapper_exit
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 445, in __init__
    self.aot_compile()
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 488, in aot_compile
    self.mlir_module, self.mlir = self.generate_ir()
                                  ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 625, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 564, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 140, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 821, in _jvp_lowering
    StableHLOConstantOp(ir.DenseElementsAttr.get(np.asarray(const))).results
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: unimplemented array format conversion from format: ?
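The "?" at the end of this error is the Python buffer-protocol format code for a boolean element (the same code the struct module uses). That suggests, though nothing in this thread confirms it, that the constant passed to ir.DenseElementsAttr.get in _jvp_lowering is a boolean array. The short NumPy sketch below only demonstrates the format codes involved; it does not call the MLIR bindings.

```python
import numpy as np

# A NumPy boolean array exports its buffer with format code "?"
mask = np.asarray([True, False, True])
print(mask.dtype, memoryview(mask).format)

# For comparison, int8 and float64 arrays export "b" and "d"
print(memoryview(mask.astype(np.int8)).format, memoryview(np.zeros(2)).format)
```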

mehrdad2m, Sep 03 '24 13:09

The following tests have already been added to Catalyst in catalyst/frontend/test/pytest/test_mid_circuit_measurement.py and are marked xfail or skip. Any fix for this issue should make them pass:

test_mcm_method_with_grad
test_mcm_method_with_value_and_grad
test_mcm_method_with_jvp
test_mcm_method_with_vjp

mehrdad2m, Sep 03 '24 13:09

Closing this issue for now, as it is not reproducible with PennyLane (v0.43.0.dev69) and Catalyst (v0.13.0.dev69). Gradient-related operations are currently unsupported due to known issues with Enzyme. I have created two new issues (#2087, #2088) to track related bugs in dynamic one-shot as part of this investigation.

lazypanda10117, Oct 03 '25 17:10

Re: "Gradient-related operations are currently unsupported due to known issues with Enzyme"

@lazypanda10117 what are you referring to here? For example, is this a new, wider gradient-related bug, or an existing known bug?

josh146, Oct 04 '25 15:10

Rather than the error Mehrdad reported, the main issue we ran into was the long-standing Enzyme issue with grad combined with vmap/for loops.

dime10, Oct 06 '25 14:10