[WIP] Allow `catalyst.cond` to take in branch functions with arguments
Context:
Currently, `catalyst.cond` requires that its conditional branch functions take no arguments. Branch functions with arguments should be allowed, especially as we move towards FTQC (fault-tolerant quantum computing).
Description of the Change:
The strategy used in #1232 to allow `cond` to take PennyLane gates as branches is extended to all callables that take arguments.
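The following minimal pure-Python sketch shows only the intended call shape: a decorator-style conditional whose branches accept arguments, which are forwarded to whichever branch is selected. `ToyCond` and `toy_cond` are hypothetical names, not Catalyst's API. The sketch evaluates the predicate eagerly. Under `qjit` the predicate is a traced runtime value, so a real implementation must handle both branches at compile time. The mechanism from #1232 is not reproduced here.

```python
from typing import Any, Callable, Optional


class ToyCond:
    """Hypothetical stand-in for a conditional with argument-taking branches.

    Not Catalyst's implementation: the predicate is a plain Python bool and
    the call arguments are forwarded to the selected branch.
    """

    def __init__(self, pred: bool, true_fn: Callable[..., Any]):
        self.pred = pred
        self.true_fn = true_fn
        self.false_fn: Optional[Callable[..., Any]] = None

    def otherwise(self, false_fn: Callable[..., Any]) -> Callable[..., Any]:
        # Register the else-branch; it should accept the same arguments as true_fn.
        self.false_fn = false_fn
        return false_fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward the same arguments to whichever branch the predicate selects.
        if self.pred:
            return self.true_fn(*args, **kwargs)
        if self.false_fn is not None:
            return self.false_fn(*args, **kwargs)
        return None


def toy_cond(pred: bool) -> Callable[[Callable[..., Any]], ToyCond]:
    # Decorator factory mirroring the `@cond(pred)` / `@f.otherwise` shape.
    def decorator(true_fn: Callable[..., Any]) -> ToyCond:
        return ToyCond(pred, true_fn)

    return decorator


# Usage: record which "gate" each branch would apply instead of running a circuit.
applied = []


@toy_cond(0 == 1)  # predicate is False, so the otherwise-branch runs
def conditional(wire):
    applied.append(("PauliX", wire))


@conditional.otherwise
def false_fn(wire):
    applied.append(("RX", wire + 1))


conditional(0)
print(applied)  # [('RX', 1)]
```

The argument passed at call time (`0`) reaches the selected branch unchanged. This is the behaviour the change enables for `catalyst.cond`.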
Benefits:
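`catalyst.cond` can now take branch functions with arguments. For example, the Catalyst program below gives the same probabilities as the PennyLane reference that uses `qml.cond`: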
```python
import pennylane as qml
import catalyst
from catalyst import qjit

@qml.qnode(qml.device("lightning.qubit", wires=2))
def ref_func():
    qml.PauliX(wires=1)  # |01>
    m0 = qml.measure(0)  # will measure 0

    def true_fn(wire):
        qml.PauliX(wires=wire)

    def false_fn(wire):  # taken, since m0 == 0
        qml.RX(1.23, wires=wire + 1)

    qml.cond(m0 == 1, true_fn, false_fn)(0)
    return qml.probs()

print("ref: ", ref_func())

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def func():
    qml.PauliX(wires=1)  # |01>
    m0 = catalyst.measure(0)  # will measure 0

    @catalyst.cond(m0 == 1)
    def conditional(wire):
        qml.PauliX(wires=wire)

    @conditional.otherwise
    def false_fn(wire):  # taken, since m0 == 0
        qml.RX(1.23, wires=wire + 1)

    conditional(0)
    return qml.probs()

print("cat: ", func())
```
```
ref: [0.33288114 0.66711886 0. 0. ]
cat: [0.33288114 0.66711886 0. 0. ]
```
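As a sanity check, the expected probabilities for the printed results can be worked out by hand. After `PauliX(wires=1)` the state is |01>. Measuring wire 0 yields 0, so the `false_fn` branch applies `RX(1.23)` to wire 1, which is in |1>. Because RX(θ)|1> = -i·sin(θ/2)|0> + cos(θ/2)|1>, only |00> and |01> end up with non-zero probability. This sketch uses only the standard library and is not part of the PR itself.

```python
import math

theta = 1.23  # RX angle applied by the false branch

# Wire 0 stays |0>; wire 1 goes from |1> to -i*sin(theta/2)|0> + cos(theta/2)|1>.
p00 = math.sin(theta / 2) ** 2  # probability of |00>
p01 = math.cos(theta / 2) ** 2  # probability of |01>
expected = [p00, p01, 0.0, 0.0]  # ordering |00>, |01>, |10>, |11>
print(expected)
```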