`insertion_sort_transformer` causes "Measurement key missing" with classical control
Describe the issue
Applying `insertion_sort_transformer` to a circuit that contains a `CircuitOperation` classically controlled by a measurement key results in a `ValueError`. The error does not occur if `defer_measurements` is applied first, or if the transformer is not applied at all.
Explain how to reproduce the bug or problem
To reproduce:
```python
import cirq
from cirq.transformers import defer_measurements, insertion_sort_transformer

q = cirq.LineQubit.range(3)
circuit = cirq.Circuit()

# Sub-circuit: X on q[0], quantum-controlled by q[1].
sub_circuit = cirq.Circuit()
sub_circuit.append(cirq.X(q[0]).controlled_by(q[1]))
sub_op = cirq.CircuitOperation(sub_circuit.freeze())

# Measure q[2] into key "c", then apply the sub-circuit classically controlled by "c".
circuit.append(cirq.measure(q[2], key="c"))
circuit.append(sub_op.with_classical_controls("c"))
circuit.append(cirq.measure(q, key="m"))

# Uncommenting the next line avoids the error.
# circuit = defer_measurements(circuit)
circuit = insertion_sort_transformer(circuit)

simulator = cirq.Simulator()
result = simulator.run(circuit, repetitions=500)
print(result.histogram(key='m'))
```
```
Traceback (most recent call last):
  File "C:\Users\temp_test.py", line 22, in
```
Tell us the version of Cirq where this happens
1.6.0.dev20250702012506
Confirmed. It looks like the insertion sort transformer skips all measurement-key commutativity checks whenever two operations act on disjoint qubits: https://github.com/quantumlib/cirq/blob/9c376c4e72ec924ee909a6c31147d703e69f1078/cirq-core/cirq/transformers/insertion_sort.py#L57. If I understand the intent correctly, those special cases exist only for performance and are not required for correctness. If so, I think the best option is simply to remove them. `cirq.commutes` itself has similar optimizations, but its implementation also checks measurement keys, so relying on it should guarantee correctness without much of a performance hit.
cc @NoureldinYosri
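To illustrate the fix proposed above, here is a minimal, self-contained sketch in plain Python; it does not use Cirq. `ToyOp`, `keys_conflict`, `can_swap` and `insertion_sort` are hypothetical names invented for this illustration, not Cirq's API, and the model tracks only qubit sets and key names. The point is the order of the checks: classical data dependencies (a key written by one operation and read or written by the other) are tested before the disjoint-qubit shortcut, so a classically controlled operation is never moved ahead of the measurement it depends on.

```python
from dataclasses import dataclass
from typing import FrozenSet, List


@dataclass(frozen=True)
class ToyOp:
    """Hypothetical stand-in for an operation (not Cirq's API)."""
    name: str
    qubits: FrozenSet[int]
    measurement_keys: FrozenSet[str] = frozenset()  # keys this op writes
    control_keys: FrozenSet[str] = frozenset()      # keys this op reads


def keys_conflict(a: ToyOp, b: ToyOp) -> bool:
    # Order matters if one op writes a key that the other writes or reads.
    return bool(
        (a.measurement_keys & b.measurement_keys)
        or (a.measurement_keys & b.control_keys)
        or (a.control_keys & b.measurement_keys)
    )


def can_swap(a: ToyOp, b: ToyOp) -> bool:
    # Check classical dependencies BEFORE the disjoint-qubit shortcut.
    if keys_conflict(a, b):
        return False
    # Disjoint qubits and no shared keys: safe to reorder. Overlapping
    # qubits are conservatively treated as non-commuting in this toy model.
    return a.qubits.isdisjoint(b.qubits)


def insertion_sort(ops: List[ToyOp]) -> List[ToyOp]:
    """Sort by sorted qubit indices, moving an op left only past ops it can swap with."""
    result: List[ToyOp] = []
    for op in ops:
        pos = len(result)
        while (
            pos > 0
            and sorted(result[pos - 1].qubits) > sorted(op.qubits)
            and can_swap(result[pos - 1], op)
        ):
            pos -= 1
        result.insert(pos, op)
    return result


measure_c = ToyOp("measure", frozenset({1}), measurement_keys=frozenset({"c"}))
controlled_x = ToyOp("X", frozenset({0}), control_keys=frozenset({"c"}))
print([op.name for op in insertion_sort([measure_c, controlled_x])])  # ['measure', 'X']
```

An unconditional `ToyOp("X", frozenset({0}))` would still be sorted ahead of the measurement, so the key check only blocks swaps that would break a classical dependency. For operations on overlapping qubits, a real transformer would fall back to a full commutation check such as `cirq.commutes` instead of refusing the swap outright.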
Here's a simpler repro that needs only two gates and no subcircuits:
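```python
import cirq
from cirq.transformers import insertion_sort_transformer

q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
    cirq.measure(q1, key="c"),
    cirq.X(q0).with_classical_controls("c"),
)
print(repr(circuit))

# The transformer sorts the X on q0 ahead of the measurement that produces "c".
circuit = insertion_sort_transformer(circuit)
print(repr(circuit))

# Fails: the classical control reads "c" before it has been measured.
cirq.Simulator().run(circuit)
```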