[BUG][Return-types] Sample does not return the right type with interfaces
Expected behavior
With the new return types enabled via qml.enable_return(), we expect qml.sample to return a tensor of the interface in use (torch.Tensor for Torch, tf.Tensor for TensorFlow, a JAX array for JAX). Instead, it returns a numpy.ndarray in almost every case; only the default.qubit.jax runs return a JAX array.
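To make the expected behavior concrete, here is a minimal sketch of a type check that could be run on each result. The names `EXPECTED_LIBRARIES`, `result_library` and `matches_interface` are hypothetical helpers, not part of PennyLane. The sketch assumes the array library can be read from the top-level package in `type(obj).__module__`, which is `numpy` for `numpy.ndarray`, `torch` for `torch.Tensor`, and `jaxlib` for the `DeviceArray` type printed in the output below.

```python
# Hypothetical helpers (not part of PennyLane): map each interface to the
# top-level package(s) whose types we would expect a QNode to return.
EXPECTED_LIBRARIES = {
    "tf": {"tensorflow"},
    "torch": {"torch"},
    # JAX array types live in jaxlib, e.g. jaxlib.xla_extension.DeviceArray.
    "jax": {"jax", "jaxlib"},
}


def result_library(obj):
    """Return the top-level package that defines type(obj), e.g. 'numpy'."""
    return type(obj).__module__.split(".")[0]


def matches_interface(obj, interface):
    """True if obj's type comes from a library associated with `interface`."""
    return result_library(obj) in EXPECTED_LIBRARIES[interface]
```

In the Torch part of the reproduction, `matches_interface(res, "torch")` would be expected to hold after each call. With the current behavior it does not, because `result_library(res)` is `"numpy"`.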
Actual behavior
import pennylane as qml
import tensorflow as tf

qml.enable_return()

# TensorFlow interface
measurements = [qml.sample(qml.PauliZ(0)), qml.sample(wires=[0]), qml.sample(wires=[0, 1])]
devices = ["default.qubit.tf", "default.mixed"]
shots = 100

for m in measurements:
    for d in devices:
        dev = qml.device(d, wires=2, shots=shots)

        def circuit(x):
            qml.Hadamard(wires=[0])
            qml.CRX(x, wires=[0, 1])
            return qml.apply(m)

        qnode = qml.QNode(circuit, dev, diff_method=None)
        res = qnode(tf.Variable(0.5))
        print(type(res))

# Torch interface
import torch

measurements = [qml.sample(qml.PauliZ(0)), qml.sample(wires=[0]), qml.sample(wires=[0, 1])]
devices = ["default.qubit.torch", "default.mixed"]
shots = 100

for m in measurements:
    for d in devices:
        dev = qml.device(d, wires=2, shots=shots)

        def circuit(x):
            qml.Hadamard(wires=[0])
            qml.CRX(x, wires=[0, 1])
            return qml.apply(m)

        qnode = qml.QNode(circuit, dev, diff_method=None)
        res = qnode(torch.tensor(0.5))
        print(type(res))

# JAX interface
import jax.numpy as jnp

devices = ["default.qubit.jax", "default.mixed"]
shots = 100

for m in measurements:
    for d in devices:
        dev = qml.device(d, wires=2, shots=shots)

        def circuit(x):
            qml.Hadamard(wires=[0])
            qml.CRX(x, wires=[0, 1])
            return qml.apply(m)

        qnode = qml.QNode(circuit, dev, diff_method=None)
        res = qnode(jnp.array(0.5))
        print(type(res))
Output:
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
<class 'jaxlib.xla_extension.DeviceArray'>
<class 'numpy.ndarray'>
<class 'jaxlib.xla_extension.DeviceArray'>
<class 'numpy.ndarray'>
<class 'jaxlib.xla_extension.DeviceArray'>
<class 'numpy.ndarray'>
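Because the script prints one type per run without labels, it can help to pair each printed type with the run that produced it. The sketch below is a hypothetical reading aid, not part of PennyLane. It assumes the printed order follows the script's loops (for each section, measurements outer and devices inner) and rebuilds the reported output with the absl warning dropped.

```python
import re
from collections import defaultdict

# Runs in the order the reproduction script executes them.
MEASUREMENTS = ["sample(PauliZ(0))", "sample(wires=[0])", "sample(wires=[0, 1])"]
SECTIONS = [
    ("tf", ["default.qubit.tf", "default.mixed"]),
    ("torch", ["default.qubit.torch", "default.mixed"]),
    ("jax", ["default.qubit.jax", "default.mixed"]),
]

# The reported output without the absl warning: 12 NumPy arrays, then three
# alternating DeviceArray / ndarray pairs.
PRINTED = "<class 'numpy.ndarray'> " * 12 + (
    "<class 'jaxlib.xla_extension.DeviceArray'> <class 'numpy.ndarray'> " * 3
)


def label_results(printed):
    """Pair each printed type name with its (section, measurement, device) run."""
    types = re.findall(r"<class '([^']+)'>", printed)
    runs = [
        (section, m, d)
        for section, devices in SECTIONS
        for m in MEASUREMENTS  # outer loop in the script
        for d in devices  # inner loop in the script
    ]
    if len(types) != len(runs):
        raise ValueError(f"expected {len(runs)} types, found {len(types)}")
    return list(zip(runs, types))


def types_by_device(labelled):
    """Collect the set of returned type names for each (section, device) pair."""
    summary = defaultdict(set)
    for (section, _measurement, device), type_name in labelled:
        summary[(section, device)].add(type_name)
    return dict(summary)


if __name__ == "__main__":
    for (section, device), names in types_by_device(label_results(PRINTED)).items():
        print(f"{section:<6} {device:<20} {sorted(names)}")
```

Read this way, the reported run shows that only the default.qubit.jax runs print a JAX DeviceArray; every other (section, device) pair prints numpy.ndarray for all three sample measurements.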
Existing GitHub issues
- [X] I have searched existing GitHub issues to make sure the issue does not already exist.