User-facing call to parameter shift differentiator
Our use case involves using a `Sample` layer to obtain bitstrings and then computing a loss as a function of those bitstrings. We want gradients of the loss with respect to the parameters of the circuit given to the `Sample` layer, and we want to compute them with the parameter-shift rule. What would it take to break out the parameter-shift code into a user-facing call?
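When the loss depends on the circuit only through sampled bitstrings, the parameter-shift rule applies to the *expected* loss L(θ) = Σ_x p_θ(x) E(x). That quantity is the expectation value of the diagonal observable Σ_x E(x)|x⟩⟨x|, so the usual rule for expectation values carries over. The following is a minimal NumPy-only sketch of what a user-facing call might look like. It assumes each parameter appears exactly once, in a gate exp(-iθP/2) where P is a Pauli string (as for `cirq.rx`). `parameter_shift_grad` and `rx_expected_loss` are hypothetical helpers, not part of the TFQ API, and the single-qubit loss is computed exactly here rather than sampled:

```python
import numpy as np


def parameter_shift_grad(f, thetas, shift=np.pi / 2):
    """Parameter-shift gradient of a scalar function f(thetas).

    Assumes every parameter enters the circuit once, via exp(-i * theta * P / 2)
    with P a Pauli string, so f is a sinusoid of frequency 1 in each parameter and
    df/dtheta_k = (f(theta + s e_k) - f(theta - s e_k)) / (2 sin s) exactly.
    """
    thetas = np.asarray(thetas, dtype=float)
    grad = np.zeros_like(thetas)
    for k in range(thetas.size):
        e_k = np.zeros_like(thetas)
        e_k.flat[k] = shift  # shift only the k-th parameter
        grad.flat[k] = (f(thetas + e_k) - f(thetas - e_k)) / (2 * np.sin(shift))
    return grad


def rx_expected_loss(thetas, logit=5.0):
    """Exact E_x[energy(x)] for x measured from rx(theta)|0> on a single qubit."""
    theta = float(np.asarray(thetas).ravel()[0])
    probs = np.array([np.cos(theta / 2) ** 2, np.sin(theta / 2) ** 2])  # p(0), p(1)
    bits = np.array([0.0, 1.0])
    # Same formula as tf.nn.sigmoid_cross_entropy_with_logits(labels=bits, logits=logit).
    energies = np.maximum(logit, 0) - logit * bits + np.log1p(np.exp(-abs(logit)))
    return float(probs @ energies)


grad = parameter_shift_grad(rx_expected_loss, [7.0])
print(grad)
```

With s = π/2 the denominator is 2, which gives the familiar form of the rule. Other shifts also give the exact derivative for gates of this form, but they trade off differently against shot noise once f is estimated from samples.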
Here is example code showing how we would like to use the `Sample` layer with gradients:
```python
import cirq
import sympy
import tensorflow as tf
import tensorflow_quantum as tfq


def energy(logits, samples):
    # Energy (negative log-likelihood) of `samples` under a Bernoulli with logit `logits`.
    samples = tf.cast(samples, dtype=tf.float32)
    return tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=samples, logits=logits))


def sample(phi, n):
    # Single-qubit circuit with one parameterized rx gate.
    circuit = cirq.Circuit(cirq.rx(sympy.Symbol('rx'))(cirq.GridQubit(0, 0)))
    samples = tfq.layers.Sample()(tfq.convert_to_tensor([circuit]),
                                  symbol_names=['rx'],
                                  symbol_values=tf.expand_dims(phi, 0),
                                  repetitions=n.numpy())  # <- hacked
    return tf.cast(samples, tf.bool)


@tf.custom_gradient
def loss(phis, num_samples):
    def grad(dphis):
        # TODO!!! Parameter shift gradient
        return None, None
    return bare_loss(phis, num_samples), grad


def bare_loss(phis, num_samples):
    samples = sample(phis, num_samples)
    energies = tf.TensorArray(tf.dtypes.float32, num_samples)
    for i in tf.range(num_samples):
        # samples[0][i] is the i-th bitstring for the first (only) circuit.
        this_energy = energy(tf.constant([5.0]), samples[0][i])
        energies = energies.write(i, tf.expand_dims(this_energy, 0))
    energies = energies.stack()
    return tf.reduce_mean(energies)


test_phi = tf.constant([7.0])
num_samples = tf.constant(3)
with tf.GradientTape() as g:
    g.watch(test_phi)
    loss_val = loss(test_phi, num_samples)
phi_grad = g.gradient(loss_val, test_phi)
print(loss_val)
print("phi grad: {}".format(phi_grad))
```
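For the single `rx` parameter in this example, one way the `grad` TODO could be filled is to estimate the loss at `phis + π/2` and `phis - π/2` with fresh samples, take half the difference, and scale by the upstream gradient `dphis`. The sketch below mirrors that structure in NumPy so it runs without TFQ. `sample_bits` is a hypothetical stand-in for the `Sample` layer that simulates rx(θ)|0⟩ directly, and `sampled_loss` and `shot_parameter_shift_grad` are illustrative names, not TFQ functions:

```python
import numpy as np


def sample_bits(theta, n, rng):
    """Hypothetical stand-in for tfq.layers.Sample on rx(theta)|0>: n single-qubit outcomes."""
    p_one = np.sin(theta / 2) ** 2  # Born-rule probability of measuring 1
    return (rng.random(n) < p_one).astype(float)


def bernoulli_energy(logit, bits):
    # Same formula as tf.nn.sigmoid_cross_entropy_with_logits(labels=bits, logits=logit).
    return np.maximum(logit, 0) - logit * bits + np.log1p(np.exp(-abs(logit)))


def sampled_loss(theta, n, rng, logit=5.0):
    """Mean energy over n sampled bitstrings (mirrors bare_loss in the issue)."""
    return float(np.mean(bernoulli_energy(logit, sample_bits(theta, n, rng))))


def shot_parameter_shift_grad(theta, n, rng, upstream=1.0):
    """Sample-based parameter-shift estimate of d(expected loss)/d(theta) for an rx gate.

    Each shifted loss is estimated from its own n fresh samples, so the estimate
    is unbiased but carries shot noise that shrinks like 1/sqrt(n).
    """
    plus = sampled_loss(theta + np.pi / 2, n, rng)
    minus = sampled_loss(theta - np.pi / 2, n, rng)
    return upstream * (plus - minus) / 2


rng = np.random.default_rng(0)
print(shot_parameter_shift_grad(7.0, 200_000, rng))
```

Because each shifted sampled loss is an unbiased estimate of the expected loss at that shift, their half-difference is an unbiased estimate of the gradient. With only 3 samples per evaluation, as in the example above, the estimate would be very noisy, so more repetitions per shift would likely be needed in practice.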
We started prototyping solutions in a notebook: https://colab.research.google.com/drive/1h6aUeQ6rLVhd0sE600kA0X6hDSAazw4o#scrollTo=ZXpGyhYSNpFK
@geoffroeder
@zaqqwerty This issue has sat for a long time, which is regrettable. For purposes of planning work and doing repository housekeeping, could you let us know what the status of this is?
Due to time constraints, we have to prioritize more aggressively. I'm closing this as not essential, based on both the description and the lack of activity since 2020.