catalyst
[BUG] Classical pre-processing not working when using `grad` with Enzyme
The following code, which applies classical post-processing (but no pre-processing) to the function we are computing the gradient of, works correctly:
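```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import grad, qjit

dev = qml.device("lightning.qubit", wires=1)  # assumed device (same as the one used further below)

@qjit
def f(x):
    @qml.qnode(dev)
    def g(y):
        qml.RX(y, wires=0)
        return qml.expval(qml.PauliZ(0))

    return grad(lambda y: g(y) ** 2)(x)
```

```pycon
>>> f(0.4)
array(-0.99994172)
```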
However, if we introduce classical pre-processing on the QNode argument, this no longer works:
```python
@qjit
def f(x):
    @qml.qnode(dev)
    def g(y):
        qml.RX(y, wires=0)
        return qml.expval(qml.PauliZ(0))

    return grad(lambda y: g(jnp.cos(y)) ** 2)(x)
```
```pycon
>>> f(0.4)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-17-86870092fac7> in <cell line: 1>()
----> 1 f(0.4)

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in __call__(self, *args, **kwargs)
    645 return self.user_function(*args, **kwargs)
    646
--> 647 function, args = self._ensure_real_arguments_and_formal_parameters_are_compatible(
    648 self.compiled_function, *args
    649 )

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in _ensure_real_arguments_and_formal_parameters_are_compatible(self, function, *args)
    620 if not self.compiling_from_textual_ir:
    621 self.mlir_module = self.get_mlir(*r_sig)
--> 622 function = self.compile()
    623 else:
    624 assert next_action == TypeCompatibility.CAN_SKIP_PROMOTION

/usr/local/lib/python3.10/dist-packages/catalyst/compilation_pipelines.py in compile(self)
    579 qfunc_name = str(self.mlir_module.body.operations[0].name).replace('"', "")
    580
--> 581 shared_object, llvm_ir, inferred_func_data = self.compiler.run(
    582 self.mlir_module, pipelines=self.compile_options.pipelines
    583 )

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run(self, mlir_module, *args, **kwargs)
    399 """
    400
--> 401 return self.run_from_ir(
    402 mlir_module.operation.get_asm(
    403 binary=False, print_generic_op_form=False, assume_verified=True

/usr/local/lib/python3.10/dist-packages/catalyst/compiler.py in run_from_ir(self, ir, module_name, pipelines, lower_to_llvm)
    356 print(f"[LIB] Running compiler driver in {workspace}", file=self.options.logfile)
    357
--> 358 compiler_output = run_compiler_driver(
    359 ir,
    360 workspace,

RuntimeError: Compilation failed:
operand #0 does not dominate this use
```
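As a reference point for the failing example, the gradient it should return can be computed without Catalyst. This is a minimal sketch that assumes the QNode `g` can be replaced by its closed form: for `RX(y)` applied to |0⟩ the expectation value ⟨Z⟩ equals cos(y). The stand-in `g_analytic` is a hypothetical helper, not part of the report; plain `jax.grad` then differentiates through the same `jnp.cos` pre-processing and `** 2` post-processing.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the QNode `g`: for RX(y) acting on |0>,
# the expectation value <Z> equals cos(y).
def g_analytic(y):
    return jnp.cos(y)

# Same composite function as the failing example:
# pre-processing (jnp.cos), then the "QNode", then post-processing (** 2).
def cost(y):
    return g_analytic(jnp.cos(y)) ** 2

expected = jax.grad(cost)(0.4)
print(expected)  # approximately 0.3752
```

Analytically this is sin(2 cos y) · sin y evaluated at y = 0.4. Since the workaround further below computes the same function with `jnp.cos` moved inside the QNode, its printed value can be compared against this number.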
Note that this does work correctly with finite differences (`method="fd"`):
```python
@qjit
def f(x):
    @qml.qnode(dev)
    def g(y):
        qml.RX(y, wires=0)
        return qml.expval(qml.PauliZ(0))

    return grad(lambda y: g(jnp.sin(y)) ** 2, method="fd")(x)
```

```pycon
>>> f(0.4)
array(-0.64700111)
```
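The finite-difference value can be checked against the exact gradient with the same assumption as before: the QNode is replaced by the hypothetical closed-form stand-in ⟨Z⟩ = cos(y) for `RX(y)` on |0⟩. Note that this example uses `jnp.sin` rather than `jnp.cos` as the pre-processing, so its expected value differs from the failing example.

```python
import jax
import jax.numpy as jnp

def g_analytic(y):
    # Hypothetical stand-in for the QNode: <Z> after RX(y) on |0> is cos(y).
    return jnp.cos(y)

def cost(y):
    # Pre-processing with jnp.sin, as in the finite-difference example.
    return g_analytic(jnp.sin(y)) ** 2

exact = jax.grad(cost)(0.4)
fd_reported = -0.64700111  # value printed by the method="fd" run
print(exact, abs(exact - fd_reported))
```

Analytically the gradient is -sin(2 sin y) · cos y, which at y = 0.4 is about -0.6470, so the finite-difference result agrees with the exact value to within float32 precision.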
@erick-xanadu points out that the following works, where the classical pre-processing is moved inside the QNode:
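```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import grad, qjit

device = qml.device("lightning.qubit", wires=1)

@qjit(keep_intermediate=True)
def f(x):
    @qml.qnode(device)
    def g(y):
        z = jnp.cos(y)  # classical pre-processing now happens inside the QNode
        qml.RX(z, wires=0)
        return qml.expval(qml.PauliZ(0))

    def post(y):
        return g(y) ** 2

    return grad(post)(x)

print(f(0.4))
```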
There seems to be an implicit assumption that classical pre-processing can only occur inside a QNode.
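To make the role of the pre-processing concrete: for `post(g(pre(y)))`, reverse-mode differentiation has to propagate the adjoint back through the classical `pre` step as well as through the QNode, so the gradient factors as post'(g) · g'(pre(y)) · pre'(y). In the failing example the factor pre'(y) comes from code outside the QNode, while in the workaround it is folded into the QNode. The sketch below checks this factorization in plain JAX under the same assumption as above (the QNode is replaced by its closed form ⟨Z⟩ = cos(z)); the names `pre`, `qnode_standin`, and `post` are hypothetical and not Catalyst APIs.

```python
import jax
import jax.numpy as jnp

def pre(y):
    # Classical pre-processing applied to the QNode argument.
    return jnp.cos(y)

def qnode_standin(z):
    # Hypothetical stand-in for the QNode: <Z> after RX(z) on |0> is cos(z).
    return jnp.cos(z)

def post(e):
    # Classical post-processing applied to the QNode output.
    return e ** 2

def composite(y):
    return post(qnode_standin(pre(y)))

y = 0.4
z = pre(y)
e = qnode_standin(z)

# Chain rule: each stage contributes one derivative factor.
factors = jax.grad(post)(e) * jax.grad(qnode_standin)(z) * jax.grad(pre)(y)
direct = jax.grad(composite)(y)
print(factors, direct)  # both approximately 0.3752
```

Both printed values agree, which shows that the pre-processing contributes an ordinary classical derivative factor regardless of whether it is written inside or outside the QNode.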
Note: I have reproduced the error from the description using the version of Catalyst that adds verbose error logging. Here is the report: