
[BUG] Classical pre-processing not working when using `grad` with enzyme

Open josh146 opened this issue 2 years ago • 3 comments

The following code, which applies only classical post-processing (squaring the QNode output) and no pre-processing to the QNode argument, works correctly:

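```python
import pennylane as qml
from catalyst import grad, qjit

dev = qml.device("lightning.qubit", wires=1)  # assumed; matches the device used later in this thread

@qjit
def f(x):
    @qml.qnode(dev)
    def g(y):
        qml.RX(y, wires=0)
        return qml.expval(qml.PauliZ(0))

    # Only classical post-processing (squaring) of the QNode output.
    return grad(lambda y: g(y) ** 2)(x)
```

```
>>> f(0.4)
array(-0.99994172)
```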

However, if we introduce classical pre-processing on the QNode argument, this no longer works:

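```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import grad, qjit

dev = qml.device("lightning.qubit", wires=1)  # assumed; matches the device used later in this thread

@qjit
def f(x):
    @qml.qnode(dev)
    def g(y):
        qml.RX(y, wires=0)
        return qml.expval(qml.PauliZ(0))

    # Classical pre-processing (jnp.cos) applied to the QNode argument.
    return grad(lambda y: g(jnp.cos(y)) ** 2)(x)
```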
```
>>> f(0.4)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)

<ipython-input-17-86870092fac7> in <cell line: 1>()
----> 1 f(0.4)

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in __call__(self, *args, **kwargs)
    645             return self.user_function(*args, **kwargs)
    646 
--> 647         function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
    648             self.compiled_function, *args
    649         )

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args)
    620             if not self.compiling_from_textual_ir:
    621                 self.mlir_module = self.get_mlir(*r_sig)
--> 622             function = self.compile()
    623         else:
    624             assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in compile(self)
    579             qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
    580 
--> 581             shared_object, llvm_ir, inferred_func_data = self.compiler.run(
    582                 self.mlir_module, pipelines=self.compile_options.pipelines
    583             )

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run(self, mlir_module, *args, **kwargs)
    399         """
    400 
--> 401         return self.run_from_ir(
    402             mlir_module.operation.get_asm(
    403                 binary=False, print_generic_op_form=False, assume_verified=True

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run_from_ir(self, ir, module_name, pipelines, lower_to_llvm)
    356             print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile)
    357 
--> 358         compiler_output = run_compiler_driver(
    359             ir,
    360             workspace,

RuntimeError: Compilation failed:
operand #0 does not dominate this use
```
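For reference, the value the failing example should return once it compiles can be worked out by hand. This is a minimal sketch, assuming `dev` is a noiseless statevector simulator so that `g(θ)` (a single `RX` on |0⟩ followed by ⟨Z⟩) equals cos θ. Under that assumption the differentiated function is cos(cos y)², whose derivative is sin(2 cos y)·sin y. The helper names (`loss`, `expected_grad`, `central_difference`) are illustrative only and are not part of Catalyst.

```python
import math

def loss(y: float) -> float:
    # g(jnp.cos(y)) ** 2 with the assumed model g(theta) = cos(theta)
    return math.cos(math.cos(y)) ** 2

def expected_grad(y: float) -> float:
    # Chain rule: 2*cos(cos y) * (-sin(cos y)) * (-sin y) = sin(2*cos y) * sin y
    return math.sin(2.0 * math.cos(y)) * math.sin(y)

def central_difference(fn, y: float, h: float = 1e-6) -> float:
    # Symmetric finite difference, used here only as an independent check.
    return (fn(y + h) - fn(y - h)) / (2.0 * h)

x = 0.4
print(expected_grad(x))             # approximately 0.3752
print(central_difference(loss, x))  # approximately 0.3752
```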

josh146, Sep 30 '23 23:09

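Note that the version with pre-processing does work correctly with finite differences (`method="fd"`), here with `jnp.sin` as the pre-processing function: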

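```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import grad, qjit

dev = qml.device("lightning.qubit", wires=1)  # assumed; matches the device used later in this thread

@qjit
def f(x):
    @qml.qnode(dev)
    def g(y):
        qml.RX(y, wires=0)
        return qml.expval(qml.PauliZ(0))

    # Same pre-processing pattern, differentiated with finite differences.
    return grad(lambda y: g(jnp.sin(y)) ** 2, method="fd")(x)
```

```
>>> f(0.4)
array(-0.64700111)
```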
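The finite-difference result can be sanity-checked against the analytic derivative. This is a minimal sketch under the same assumption (⟨Z⟩ after `RX(θ)` on |0⟩ is cos θ on a noiseless simulator). The one-sided formula and the step size `h` are illustrative choices; they are not claimed to match how Catalyst implements `method="fd"` internally.

```python
import math

def loss(y: float) -> float:
    # g(jnp.sin(y)) ** 2 with the assumed model g(theta) = cos(theta)
    return math.cos(math.sin(y)) ** 2

def analytic_grad(y: float) -> float:
    # Chain rule: 2*cos(sin y) * (-sin(sin y)) * cos(y) = -sin(2*sin y) * cos(y)
    return -math.sin(2.0 * math.sin(y)) * math.cos(y)

def forward_difference(fn, y: float, h: float = 1e-7) -> float:
    # One-sided finite difference; truncation error is O(h).
    return (fn(y + h) - fn(y)) / h

x = 0.4
print(analytic_grad(x))             # approximately -0.6470
print(forward_difference(loss, x))  # approximately -0.6470
```

Both values agree with the reported `array(-0.64700111)` to at least four decimal places, which suggests the finite-difference path handles the pre-processing correctly.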

josh146, Oct 01 '23 01:10

@erick-xanadu points out that the following works, where the pre-processing `jnp.cos` is moved inside the QNode:

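```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import grad, qjit

device = qml.device("lightning.qubit", wires=1)

@qjit(keep_intermediate=True)
def f(x):
    @qml.qnode(device)
    def g(y):
        # Classical pre-processing now happens inside the QNode.
        z = jnp.cos(y)
        qml.RX(z, wires=0)
        return qml.expval(qml.PauliZ(0))

    def post(y):
        return g(y) ** 2

    return grad(post)(x)

print(f(0.4))
```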

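There seems to be an implicit assumption in the Enzyme-based `grad` that classical pre-processing can only occur inside a QNode: the same computation compiles when `jnp.cos` is applied inside the QNode, but not when it is applied to the QNode argument from outside.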
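Moving `jnp.cos` into the QNode does not change the function being differentiated; it only changes which side of the QNode boundary owns the classical Jacobian d(cos y)/dy. A minimal sketch of the two bookkeeping choices, assuming as before that the circuit evaluates to cos θ on a noiseless simulator. `grad_outside` and `grad_inside` are hypothetical names that mimic the two formulations; they are not Catalyst internals. A correct `grad` should return the same value for both.

```python
import math

def circuit(theta: float) -> float:
    # Assumed model: RX(theta) on |0>, then <Z>, gives cos(theta).
    return math.cos(theta)

def circuit_grad(theta: float) -> float:
    # d<Z>/dtheta; equals the parameter-shift value for RX.
    return -math.sin(theta)

def grad_outside(y: float) -> float:
    # Failing form g(jnp.cos(y)) ** 2: the QNode is differentiated with respect
    # to its own argument theta, and the classical Jacobian dtheta/dy = -sin(y)
    # must be applied by the surrounding program.
    theta = math.cos(y)
    return 2.0 * circuit(theta) * circuit_grad(theta) * (-math.sin(y))

def grad_inside(y: float) -> float:
    # Workaround form: z = jnp.cos(y) is computed in the QNode body, so the
    # QNode's gradient with respect to y already contains dz/dy.
    z = math.cos(y)
    qnode_grad_wrt_y = circuit_grad(z) * (-math.sin(y))
    return 2.0 * circuit(z) * qnode_grad_wrt_y

for x in (0.0, 0.4, 1.3):
    print(x, grad_outside(x), grad_inside(x))
```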

josh146, Oct 06 '23 14:10

Note: I have reproduced the error from the description using the version of Catalyst that adds verbose error logging. Here is the report:

sergei-mironov, Oct 10 '23 13:10