
[BUG] When transforms are used with AutoGraph, they may not be applied at all.

Open josh146 opened this issue 1 year ago • 2 comments

For example, consider the following:

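```python
import pennylane as qml

# Device choice is an assumption; the original report does not define `dev`
dev = qml.device("lightning.qubit", wires=1)

@qml.transform
def my_quantum_transform(tape):
    new_operations = []

    # Iterate over operations only; the measurements are passed separately below
    for op in tape.operations:
        if op.name == "RX":
            if op.parameters[0] < 0:
                # Replace RX(theta) with theta < 0 by RY(-theta)
                new_operations.append(qml.RY(-op.parameters[0], wires=op.wires))
            else:
                new_operations.append(op)
        else:
            new_operations.append(op)

    new_tape = type(tape)(new_operations, tape.measurements, shots=tape.shots)

    def post_processing_fn(results):
        return results[0]

    return [new_tape], post_processing_fn

@qml.qjit(autograph=True)
@my_quantum_transform
@qml.qnode(dev)
def circuit(x):
    qml.RY(x, wires=0)
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))
```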

If we run this QNode with `autograph=False`, we get an error, because the Python `if` statement attempts to convert a JAX tracer to a boolean.
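The failure without AutoGraph is ordinary JAX tracing behaviour rather than something specific to transforms. The minimal plain-JAX sketch below (the function names are hypothetical and not part of PennyLane or Catalyst) shows that a Python `if` on a traced value raises a concretization error under `jax.jit`, while `jax.lax.cond` stages both branches into the compiled program. AutoGraph is meant to perform a comparable rewrite automatically for `if` statements on traced values.

```python
import jax

def flip_if_negative(x):
    # A Python `if` needs a concrete bool, but under jax.jit `x` is an abstract tracer
    if x < 0:
        return -x
    return x

def flip_if_negative_traced(x):
    # jax.lax.cond records both branches in the traced program instead
    return jax.lax.cond(x < 0, lambda v: -v, lambda v: v, x)

try:
    jax.jit(flip_if_negative)(-0.5)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError (newer JAX) is a subclass of ConcretizationTypeError
    python_if_failed = True

print(python_if_failed)                        # True
print(jax.jit(flip_if_negative_traced)(-0.5))  # 0.5
```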

With `autograph=True`, execution does not error; however, the results (and the corresponding JAXPR) show that the transform has not been applied.
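For readers unfamiliar with reading a JAXPR for this kind of evidence, the plain-JAX sketch below (not Catalyst's own tooling; `rewrite_angle` and `untouched` are hypothetical names) contrasts a function whose data-dependent branch is captured, which shows up as a `cond` primitive, with one where the rewrite was skipped, where no such branch appears.

```python
import jax

def rewrite_angle(x):
    # Hypothetical stand-in for the transform's rule: negate negative angles
    return jax.lax.cond(x < 0, lambda v: -v, lambda v: v, x)

def untouched(x):
    # The same computation with the rule dropped, i.e. a skipped rewrite
    return x

captured = str(jax.make_jaxpr(rewrite_angle)(-0.5))
skipped = str(jax.make_jaxpr(untouched)(-0.5))

print("cond" in captured)  # True
print("cond" in skipped)   # False
```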

josh146 · Jan 05 '24 19:01

Update: rewriting the transform to use queuing exhibits the same behaviour:

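```python
import pennylane as qml

@qml.transform
def my_quantum_transform(tape):
    # Record the rewritten circuit on a fresh queue
    with qml.queuing.AnnotatedQueue() as q:
        for op in tape.operations:
            if op.name == "RX":
                if op.parameters[0] < 0:
                    # Replace RX(theta) with theta < 0 by RY(-theta)
                    qml.RY(-op.parameters[0], wires=op.wires)
                else:
                    qml.apply(op)
            else:
                qml.apply(op)

        # Re-queue the original measurements so from_queue picks them up
        for m in tape.measurements:
            qml.apply(m)

    new_tape = qml.tape.QuantumTape.from_queue(q, shots=tape.shots)

    def post_processing_fn(results):
        return results[0]

    return [new_tape], post_processing_fn
```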
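Stripped of PennyLane and tracing, the rewrite rule both versions of the transform aim to implement (replace RX(θ) with θ < 0 by RY(-θ), keep everything else) can be checked on concrete values. This is a plain-Python sketch; `Op` and `rewrite` are hypothetical names, not PennyLane classes. With concrete floats the `< 0` comparison is an ordinary bool; inside `qjit` the angle is a tracer, which is exactly where the comparison becomes a problem.

```python
from typing import NamedTuple

class Op(NamedTuple):
    # Hypothetical lightweight stand-in for a PennyLane operation
    name: str
    param: float
    wire: int

def rewrite(ops):
    """Replace RX(theta) with theta < 0 by RY(-theta); keep every other op unchanged."""
    new_ops = []
    for op in ops:
        if op.name == "RX" and op.param < 0:
            new_ops.append(Op("RY", -op.param, op.wire))
        else:
            new_ops.append(op)
    return new_ops

print(rewrite([Op("RY", -0.5, 0), Op("RX", -0.5, 0)]))
# [Op(name='RY', param=-0.5, wire=0), Op(name='RY', param=0.5, wire=0)]
```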

josh146 · Jan 05 '24 20:01

Thanks to @albi3ro, we've discovered that the transform is being fully skipped when AutoGraph is enabled:

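```python
import pennylane as qml

# `dev` is assumed to be a Catalyst-compatible device; the report does not show its definition

@qml.transform
def my_quantum_transform(tape):
    # If this transform were actually executed, this error would surface
    raise NotImplementedError

@qml.qjit(autograph=True)
def f(x):
    @my_quantum_transform
    @qml.qnode(dev)
    def circuit(x):
        qml.RY(x, wires=0)
        qml.RX(x, wires=0)
        return qml.expval(qml.PauliZ(0))
    return circuit(x)
```

```pycon
>>> f(-0.5)
array(0.778)
```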

josh146 · Jan 05 '24 20:01

@tzunghanjuang can this issue now be closed?

josh146 · Jul 20 '24 01:07

@josh146 Yes. I have closed this.

tzunghanjuang · Jul 22 '24 13:07