
[WIP] Allow `catalyst.cond` to take in branch functions with arguments

Open · paul0403 opened this issue 8 months ago · 7 comments

Context: `catalyst.cond` currently requires that its conditional branch functions take no arguments. Branch functions with arguments should be supported, especially as we move towards FTQC (fault-tolerant quantum computing).

Description of the Change: The strategy used in #1232 to allow `cond` to take PennyLane gates as branches is extended to all callables with arguments.
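To illustrate the calling convention this change enables, here is a minimal pure-Python sketch in which arguments given to the conditional object are forwarded to whichever branch is selected. `ToyCond` and `toy_cond` are hypothetical names. This is not Catalyst's implementation or the strategy from #1232: the toy evaluates the predicate eagerly in Python, whereas a compiled conditional would typically need to capture both branches.

```python
from typing import Any, Callable, Optional


class ToyCond:
    """Hypothetical stand-in for a conditional whose branches take arguments."""

    def __init__(self, pred: bool, true_fn: Callable[..., Any]):
        self.pred = pred
        self.true_fn = true_fn
        self.false_fn: Optional[Callable[..., Any]] = None

    def otherwise(self, false_fn: Callable[..., Any]) -> Callable[..., Any]:
        # Register the else-branch; usable as a decorator.
        self.false_fn = false_fn
        return false_fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward the call-site arguments to the branch chosen by the predicate.
        if self.pred:
            return self.true_fn(*args, **kwargs)
        if self.false_fn is not None:
            return self.false_fn(*args, **kwargs)
        return None  # no else-branch registered: do nothing


def toy_cond(pred: bool) -> Callable[[Callable[..., Any]], ToyCond]:
    # Decorator factory mirroring the `@cond(pred)` usage pattern.
    def decorator(true_fn: Callable[..., Any]) -> ToyCond:
        return ToyCond(pred, true_fn)

    return decorator


# Usage: record which "gate" each branch would apply.
applied = []


@toy_cond(False)  # predicate is False, so the else-branch runs
def conditional(wire):
    applied.append(("PauliX", wire))


@conditional.otherwise
def false_fn(wire):
    applied.append(("RX", wire + 1))


conditional(0)
print(applied)  # [('RX', 1)]
```

In this sketch `otherwise` returns the function unchanged, so `false_fn` stays a plain function; only the object returned by `toy_cond` is called.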

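Benefits: `catalyst.cond` can take branch functions with arguments. For example, the following compares PennyLane's `qml.cond` (reference) with `catalyst.cond` under `@qjit`: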

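import pennylane as qml
import catalyst
from catalyst import qjit

# Reference: PennyLane's qml.cond, whose branch functions take arguments.
@qml.qnode(qml.device("lightning.qubit", wires=2))
def ref_func():
    qml.PauliX(wires=1)  # prepare |01>
    m0 = qml.measure(0)  # wire 0 is |0>, so m0 == 0

    def true_fn(wire):
        qml.PauliX(wires=wire)

    def false_fn(wire):  # taken, since m0 == 0
        qml.RX(1.23, wires=wire + 1)

    # The trailing (0) is forwarded to the selected branch as `wire`.
    qml.cond(m0 == 1, true_fn, false_fn)(0)

    return qml.probs()


print("ref: ", ref_func())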


# Catalyst: the same circuit using catalyst.cond with branch arguments.
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def func():
    qml.PauliX(wires=1)  # prepare |01>
    m0 = catalyst.measure(0)  # wire 0 is |0>, so m0 == 0

    @catalyst.cond(m0 == 1)
    def conditional(wire):
        qml.PauliX(wires=wire)

    @conditional.otherwise
    def false_fn(wire):  # taken, since m0 == 0
        qml.RX(1.23, wires=wire + 1)

    # With this change, the argument passed here reaches the selected branch.
    conditional(0)

    return qml.probs()


print("cat: ", func())
Output:

ref:  [0.33288114 0.66711886 0.         0.        ]
cat:  [0.33288114 0.66711886 0.         0.        ]
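As a sanity check on the printed probabilities, the short pure-Python state-vector sketch below replays the circuit by hand, with no PennyLane or Catalyst required. The helper names `apply_1q` and `rx` are hypothetical, and the wire ordering is assumed to match PennyLane's (wire 0 is the most significant bit, so index 1 is |01>). Because wire 0 starts in |0>, the measurement outcome is deterministic (m0 = 0), so the sketch does not sample or collapse the state; the false branch applies RX(1.23) to wire 1, giving probabilities sin²(0.615) for |00> and cos²(0.615) for |01>.

```python
import math

N_WIRES = 2  # wire 0 is the most significant bit of the basis-state index


def apply_1q(state, gate, wire, n_wires=N_WIRES):
    """Apply a 2x2 gate (row = output bit, column = input bit) to `wire`."""
    new = [0j] * len(state)
    shift = n_wires - 1 - wire  # bit position of `wire` in the index
    for idx, amp in enumerate(state):
        if amp == 0:
            continue
        bit = (idx >> shift) & 1
        for out_bit in (0, 1):
            out_idx = (idx & ~(1 << shift)) | (out_bit << shift)
            new[out_idx] += gate[out_bit][bit] * amp
    return new


PAULI_X = [[0, 1], [1, 0]]


def rx(theta):
    # RX(theta) = [[cos(theta/2), -i sin(theta/2)], [-i sin(theta/2), cos(theta/2)]]
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -1j * s], [-1j * s, c]]


state = [1 + 0j, 0j, 0j, 0j]         # |00>
state = apply_1q(state, PAULI_X, 1)  # |01>

# Probability that wire 0 reads 1; it is exactly 0.0 here, so m0 = 0.
p_wire0_is_1 = sum(abs(a) ** 2 for i, a in enumerate(state) if (i >> 1) & 1)
m0 = round(p_wire0_is_1)

wire = 0
if m0 == 1:
    state = apply_1q(state, PAULI_X, wire)      # true branch
else:
    state = apply_1q(state, rx(1.23), wire + 1)  # false branch

probs = [abs(a) ** 2 for a in state]
print([round(p, 8) for p in probs])
```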

paul0403, Feb 20 '25 20:02