Support for `SProd`, `Prod`, and `Sum`
Sum
Sum should be handled the same way as Hamiltonian and LinearCombination, which was partially addressed in https://github.com/amazon-braket/amazon-braket-pennylane-plugin-python/pull/252, but the same treatment should be applied to translate_result_type and translate_result in translation.py as well.
Note: Sum.ops is deprecated, so instead of measurement.obs.ops, do _, ops = measurement.obs.terms(), and then use ops.
SProd and Prod
Since SProd and Prod could be nested, they are not guaranteed to be single-term observables. For example, an SProd could be 0.1 * (qml.Z(0) + qml.X(1)), in which case it's actually a Sum. Similarly, a Prod could be qml.Z(0) @ (qml.X(0) + qml.Y(1)).
This means that the same treatment for Hamiltonian, LinearCombination and Sum should extend to SProd and Prod as well, including _translate_observable, which should register Sum, SProd and Prod all under the same dispatch function as Hamiltonian, which uses H.terms().
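Here is a minimal sketch of the shared-dispatch idea. It uses Python's `functools.singledispatch` with hypothetical toy classes (`Pauli`, `Sum`, `SProd`) and a toy `translate` function in place of the real PennyLane operators and `_translate_observable`. Each composite type is registered under one handler. That handler calls `terms()` and translates the flattened single-term operators. `Prod` is deliberately left out of this sketch, because a naive registration would recurse forever on single-term products.

```python
from dataclasses import dataclass
from functools import singledispatch


# Hypothetical toy stand-ins for PennyLane observables (not the real classes).
@dataclass(frozen=True)
class Pauli:
    name: str  # "X", "Y", "Z" or "I"
    wire: int

    def terms(self):
        return [1.0], [self]


@dataclass(frozen=True)
class Sum:
    operands: tuple

    def terms(self):
        # Flatten every operand's terms into one coefficient/operator list.
        coeffs, ops = [], []
        for operand in self.operands:
            c, o = operand.terms()
            coeffs.extend(c)
            ops.extend(o)
        return coeffs, ops


@dataclass(frozen=True)
class SProd:
    scalar: float
    base: object

    def terms(self):
        # Push the scalar down onto each term of the base observable.
        c, o = self.base.terms()
        return [self.scalar * x for x in c], o


@singledispatch
def translate(obs):
    raise TypeError(f"unsupported observable: {obs!r}")


@translate.register
def _(obs: Pauli):
    return f"{obs.name.lower()}({obs.wire})"


def _translate_multi_term(obs):
    # Shared handler: translate each flattened term and keep its coefficient.
    coeffs, ops = obs.terms()
    return [(c, translate(op)) for c, op in zip(coeffs, ops)]


# Register several composite types under the same handler.
for _cls in (Sum, SProd):
    translate.register(_cls, _translate_multi_term)


obs = SProd(0.1, Sum((Pauli("Z", 0), Pauli("X", 1))))
print(translate(obs))  # [(0.1, 'z(0)'), (0.1, 'x(1)')]
```

The handler only ever sees the flattened terms, so nesting depth does not matter. The scalar ends up on each term rather than outside a sum.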
Caveat: Prod.terms() will resolve to itself if the Prod only contains one term. For example:
```python
>>> op = qml.X(0) @ qml.Y(1)
>>> op.terms()
([1.0], [X(0) @ Y(1)])
```
This may result in infinite recursion in _translate_observable, so a base case should be added to return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in H.operands]) if H is a Prod with a single term.
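The caveat and its base case can be reproduced with a similar toy model. Everything here is hypothetical. `Prod.terms()` distributes over `Sum` factors and, like the PennyLane behaviour described in the caveat, returns a single-term product as a `Prod` equal to itself. `translate` stands in for `_translate_observable`; it joins factors into a string instead of building a Braket tensor product.

```python
from dataclasses import dataclass
from functools import reduce, singledispatch


# Hypothetical toy stand-ins for PennyLane observables (not the real classes).
@dataclass(frozen=True)
class Pauli:
    name: str
    wire: int

    def terms(self):
        return [1.0], [self]


@dataclass(frozen=True)
class Sum:
    operands: tuple

    def terms(self):
        coeffs, ops = [], []
        for operand in self.operands:
            c, o = operand.terms()
            coeffs.extend(c)
            ops.extend(o)
        return coeffs, ops


@dataclass(frozen=True)
class Prod:
    operands: tuple

    def terms(self):
        # Distribute the product over any Sum factors.
        expanded = [(1.0, [])]  # (coefficient, factors) for each term
        for operand in self.operands:
            coeffs, ops = operand.terms()
            expanded = [
                (c * c2, factors + [op])
                for c, factors in expanded
                for c2, op in zip(coeffs, ops)
            ]
        coeffs = [c for c, _ in expanded]
        # Mimics the caveat: a single-term product comes back as an equal Prod.
        ops = [fs[0] if len(fs) == 1 else Prod(tuple(fs)) for _, fs in expanded]
        return coeffs, ops


@singledispatch
def translate(obs):
    raise TypeError(f"unsupported observable: {obs!r}")


@translate.register
def _(obs: Pauli):
    return f"{obs.name.lower()}({obs.wire})"


def _translate_terms(obs):
    coeffs, ops = obs.terms()
    if isinstance(obs, Prod) and len(ops) == 1:
        # Base case: terms() would hand the same Prod back, so translate the
        # factors directly and join them with a toy tensor-product symbol.
        return reduce(lambda x, y: f"{x} @ {y}", [translate(f) for f in obs.operands])
    return [(c, translate(op)) for c, op in zip(coeffs, ops)]


for _cls in (Sum, Prod):
    translate.register(_cls, _translate_terms)


print(translate(Prod((Pauli("X", 0), Pauli("Y", 1)))))  # x(0) @ y(1)
```

Without the `isinstance(obs, Prod) and len(ops) == 1` branch, the first call would recurse until Python raises `RecursionError`.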
Note: The terms() function will unwrap any nested structures but also simplify the observable. For example:
```python
>>> op = qml.X(0) @ qml.I(1)
>>> op.terms()
([1.0], [X(0)])
```
This will create a mismatch between the number of targets in the translated observable and in the original observable. We do plan on addressing this in PennyLane by having terms() recursively unwrap the observable without doing any simplification. For now, though, do not use circuit.measurements directly in _pl_to_braket_circuit; instead, do something like:
```python
measurements = []
for mp in circuit.measurements:
    obs = mp.obs
    if isinstance(obs, (Hamiltonian, LinearCombination, Sum, SProd, Prod)):
        # Simplify up front so the measurement's wires match the translated observable.
        obs = obs.simplify()
        mp = type(mp)(obs)
    measurements.append(mp)
```
Then use measurements instead of circuit.measurements from this point on. The list of simplified measurements should also be passed into _apply_gradient_result_type and used there.
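Here is a toy version of the workaround. It assumes hypothetical `Pauli`, `Prod` and `Expectation` classes whose `simplify()` only drops identity factors. Rebuilding the measurement from the simplified observable makes its wires agree with the wires the translated observable actually acts on.

```python
from dataclasses import dataclass


# Hypothetical toy stand-ins; the real code uses PennyLane measurement processes.
@dataclass(frozen=True)
class Pauli:
    name: str  # "X", "Y", "Z" or "I"
    wire: int

    @property
    def wires(self):
        return (self.wire,)


@dataclass(frozen=True)
class Prod:
    operands: tuple  # assumed to contain Pauli factors only

    @property
    def wires(self):
        return tuple(w for op in self.operands for w in op.wires)

    def simplify(self):
        # Toy simplification: drop identity factors, as in the X(0) @ I(1) example.
        kept = tuple(op for op in self.operands if op.name != "I")
        if not kept:
            return self  # all-identity product: leave it untouched
        return kept[0] if len(kept) == 1 else Prod(kept)


@dataclass(frozen=True)
class Expectation:
    obs: object

    @property
    def wires(self):
        return self.obs.wires


def simplify_measurements(measurements):
    simplified = []
    for mp in measurements:
        obs = mp.obs
        if isinstance(obs, Prod):
            obs = obs.simplify()
            # Rebuild the measurement so its wires follow the simplified observable.
            mp = type(mp)(obs)
        simplified.append(mp)
    return simplified


mp = Expectation(Prod((Pauli("X", 0), Pauli("I", 1))))
(new_mp,) = simplify_measurements([mp])
print(mp.wires, new_mp.wires)  # (0, 1) (0,)
```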
Device
Since SProd, Prod, and Sum can all be nested, multi-term observables, they should be removed from the list of supported observables and added back only when no shots are present:
```python
@property
def observables(self) -> frozenset[str]:
    base_observables = frozenset(super().observables - {"Prod", "SProd", "Sum"})
    # Amazon Braket only supports coefficients and multiple terms when shots==0
    if not self.shots:
        return base_observables.union({"Hamiltonian", "LinearCombination", "Prod", "SProd", "Sum"})
    return base_observables
```
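The behaviour of the property can be checked in isolation with a self-contained sketch. `BaseDevice` and `SketchDevice` are hypothetical stand-ins for the real device classes, and the base observable names are illustrative only.

```python
class BaseDevice:
    """Hypothetical parent device exposing a fixed set of observable names."""

    def __init__(self, shots=None):
        self.shots = shots

    @property
    def observables(self) -> frozenset:
        return frozenset(
            {"PauliX", "PauliY", "PauliZ", "Hadamard", "Hermitian", "Prod", "SProd", "Sum"}
        )


class SketchDevice(BaseDevice):
    @property
    def observables(self) -> frozenset:
        # Composite operators may hide several terms, so drop them by default.
        base_observables = frozenset(super().observables - {"Prod", "SProd", "Sum"})
        # Multi-term observables are only advertised when running without shots.
        if not self.shots:
            return base_observables.union(
                {"Hamiltonian", "LinearCombination", "Prod", "SProd", "Sum"}
            )
        return base_observables


print(sorted(SketchDevice(shots=1000).observables))
# ['Hadamard', 'Hermitian', 'PauliX', 'PauliY', 'PauliZ']
```

Because the check is `not self.shots`, both `shots=None` and `shots=0` count as analytic mode in this sketch.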
Hi @astralcai, thank you for raising this. I shall start looking into a fix.
> This means that the same treatment for Hamiltonian, LinearCombination and Sum should extend to SProd and Prod as well, including _translate_observable, which should register Sum, SProd and Prod all under the same dispatch function as Hamiltonian, which uses H.terms().
The current _translate_observable implementations for Sum, SProd and Prod recursively call _translate_observable on their operands:
```python
@_translate_observable.register
def _(t: qml.operation.Tensor):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.obs])


@_translate_observable.register
def _(t: qml.ops.Prod):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])


@_translate_observable.register
def _(t: qml.ops.SProd):
    return t.scalar * _translate_observable(t.base)
```
Shouldn't that take care of the nesting problem?
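A toy comparison makes the question concrete. It assumes hypothetical `Pauli`, `Sum` and `SProd` classes and an invented string notation, not Braket's. Per-type recursion does translate a nested `SProd`, but the result keeps the scalar outside the sum. Flattening instead distributes the scalar onto each term.

```python
from dataclasses import dataclass


# Hypothetical toy stand-ins for PennyLane observables (not the real classes).
@dataclass(frozen=True)
class Pauli:
    name: str
    wire: int


@dataclass(frozen=True)
class Sum:
    operands: tuple


@dataclass(frozen=True)
class SProd:
    scalar: float
    base: object


def translate_recursive(obs) -> str:
    # Each node translates its children, so the output keeps the input's nesting.
    if isinstance(obs, Pauli):
        return f"{obs.name.lower()}({obs.wire})"
    if isinstance(obs, SProd):
        return f"{obs.scalar} * ({translate_recursive(obs.base)})"
    if isinstance(obs, Sum):
        return " + ".join(translate_recursive(op) for op in obs.operands)
    raise TypeError(f"unsupported observable: {obs!r}")


def flatten(obs, scale=1.0):
    # Distributes scalars down to the leaves, giving (coefficient, Pauli) pairs.
    if isinstance(obs, Pauli):
        return [(scale, obs)]
    if isinstance(obs, SProd):
        return flatten(obs.base, scale * obs.scalar)
    if isinstance(obs, Sum):
        return [term for op in obs.operands for term in flatten(op, scale)]
    raise TypeError(f"unsupported observable: {obs!r}")


obs = SProd(0.1, Sum((Pauli("Z", 0), Pauli("X", 1))))
print(translate_recursive(obs))  # 0.1 * (z(0) + x(1))
print(flatten(obs))
```

Whether a downstream consumer accepts the nested form is a separate question from whether the recursion terminates.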
> ```python
> measurements = []
> for mp in circuit.measurements:
>     obs = mp.obs
>     if isinstance(obs, (Hamiltonian, LinearCombination, Sum, SProd, Prod)):
>         obs = obs.simplify()
>         mp = type(mp)(obs)
>     measurements.append(mp)
> ```
I'm noticing that simplify alters the order of operands (at least in Prod); is this intentional?
> Shouldn't that take care of the nesting problem?
It should, but as I recall it didn't. I looked into this some time ago and couldn't make it work, which is why I suggested using the same approach for all potentially multi-term observables. You can give it a try. I don't remember exactly what the issue was, but I believe it had something to do with the Braket backend being unable to parse scalar products.
```
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/interpreter.py:545: in _
parsed = self.context.parse_pragma(node.command)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/program_context.py:455: in parse_pragma
return parse_braket_pragma(pragma_body, self.qubit_mapping)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/braket_pragmas.py:216: in parse_braket_pragma
visited = BraketPragmaNodeVisitor(qubit_table).visit(tree)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:34: in visit
return tree.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:861: in accept
return visitor.visitBraketPragma(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:14: in visitBraketPragma
return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1226: in accept
return visitor.visitBraketResultPragma(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:39: in visitBraketResultPragma
return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1290: in accept
return visitor.visitResultType(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:44: in visitResultType
return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1867: in accept
return visitor.visitObservableResultType(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/braket_pragmas.py:98: in visitObservableResultType
observables, targets = self.visit(ctx.observable())
E TypeError: cannot unpack non-iterable NoneType object
----------------------------- Captured stderr call -----------------------------
line 1:26 mismatched input '0.1' expecting {'x', 'y', 'z', 'i', 'h', 'hermitian'}
```
This occurred when trying to parse the scalar product of an observable. See this run: https://github.com/PennyLaneAI/plugin-test-matrix/actions/runs/9018042395/job/24777766316
> I'm noticing that simplify alters the order of operands (at least in Prod); is this intentional?
Simplify does not preserve the original order of operands.
Sorry for the delay; I finally managed to return to this, and I think I've found the actual issues. Looking at the device test run in #264, we observe two types of failures:
1. `TypeError: cannot unpack non-iterable NoneType object` and `mismatched input '0.1' expecting {'x', 'y', 'z', 'i', 'h', 'hermitian'}`

   This is due to attempting to run Braket Sum observables on the local simulator, which does not support them. It is fixed by your suggestion of extending the Hamiltonian treatment to CompositeOp and SProd.

2. `ValueError: Sum observable's target shape must be a nested list where each term's target length is equal to the observable term's qubits count.`

   This is because we pass the MeasurementProcess's wires, which are a flat list, into translate_result_type, which expects a list of lists for its targets:
   https://github.com/amazon-braket/amazon-braket-pennylane-plugin-python/blob/17dae39443ec30b262c61708a1dab6f10ed78953/src/braket/pennylane_plugin/braket_device.py#L257-L260
   It is fixed by mapping the wires of the MeasurementProcess itself and using those wires instead of passing in wires separately.
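The shape mismatch behind the `ValueError` can be sketched with toy types. `Term`, `nested_targets` and `targets_match` are hypothetical. They only mirror the rule quoted in the error message, not the plugin's or the Braket SDK's actual code.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple


# Hypothetical single-term observable: one Pauli letter per wire it acts on.
@dataclass(frozen=True)
class Term:
    paulis: str
    wires: Tuple[str, ...]


def nested_targets(terms: Sequence[Term], wire_map: Dict[str, int]) -> List[List[int]]:
    # One target list per term, with each wire label mapped to a device qubit index.
    return [[wire_map[w] for w in term.wires] for term in terms]


def targets_match(terms: Sequence[Term], targets) -> bool:
    # Shape rule from the error message: one target list per term,
    # each as long as that term's qubit count.
    return len(targets) == len(terms) and all(
        isinstance(t, list) and len(t) == len(term.wires)
        for t, term in zip(targets, terms)
    )


terms = [Term("ZX", ("a", "b")), Term("Y", ("c",))]
wire_map = {"a": 0, "b": 1, "c": 2}
flat = [0, 1, 2]  # a flat wire list, as passed before the fix
print(targets_match(terms, flat))  # False
print(targets_match(terms, nested_targets(terms, wire_map)))  # True
```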
Fixed in #275