[FR] gp.util.train to return a trace over multiple quantities rather than just loss
Currently, gp.util.train returns only the trace of the loss (loss vs. iterations):
losses = gp.util.train(vsgp, num_steps=num_steps)
If one needs a trace over other quantities, say the kernel parameters, one has to write their own procedure, like:
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
losses = []
locations = []
variances = []
lengthscales = []
noises = []
num_steps = 2000 if not smoke_test else 2
for i in range(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(sgpr.model, sgpr.guide)
    locations.append(sgpr.Xu.data.numpy().copy())
    variances.append(sgpr.kernel.variance.item())
    noises.append(sgpr.noise.item())
    lengthscales.append(sgpr.kernel.lengthscale.item())
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
I propose adding a trace_quantities argument to gp.util.train.
As an example, tfp.math.minimize allows this:
trace_fn = lambda traceable_quantities: {
    "loss": traceable_quantities.loss,
    "theta": theta,
}
num_steps = 150
trace = tfp.math.minimize(
    loss_fn=loss,
    num_steps=num_steps,
    optimizer=tf.optimizers.Adam(learning_rate=0.01),
    trace_fn=trace_fn,
)
Then, we can obtain the traces as follows:
fig, ax = plt.subplots(nrows=2, sharex=True, figsize=(6, 4))
ax[0].plot(range(num_steps), trace["loss"])
ax[1].plot(range(num_steps), trace["theta"])
By default, the traced quantity could simply be the loss.
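The proposed behavior can be sketched framework-free. Below is a minimal pure-Python illustration of the trace_fn pattern; minimize_with_trace is a hypothetical helper (not a Pyro or TFP API), and the quadratic loss is just a stand-in for the ELBO:

```python
# Minimal sketch of the trace_fn idea. `minimize_with_trace` is a
# hypothetical helper, not an existing Pyro or TFP function.

def minimize_with_trace(loss_and_grad, theta, lr, num_steps, trace_fn):
    """Plain gradient descent that records trace_fn(state) at every step."""
    traces = []
    for _ in range(num_steps):
        loss, grad = loss_and_grad(theta)
        traces.append(trace_fn({"loss": loss, "theta": theta}))
        theta = theta - lr * grad
    # Transpose list-of-dicts into dict-of-lists, like tfp.math.minimize.
    return {k: [t[k] for t in traces] for k in traces[0]}

# Toy quadratic loss (theta - 3)^2 with gradient 2 * (theta - 3).
loss_and_grad = lambda th: ((th - 3.0) ** 2, 2.0 * (th - 3.0))

trace = minimize_with_trace(
    loss_and_grad, theta=0.0, lr=0.1, num_steps=50,
    trace_fn=lambda state: {"loss": state["loss"], "theta": state["theta"]},
)
```

The returned trace["loss"] and trace["theta"] are per-step lists, so they can be plotted exactly as in the tfp snippet above.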
Note you can already do this by wrapping loss_fn in a tracing hook:
sgpr = ...
optimizer = ...
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
# Wrap loss_fn in a tracing hook.
locations = []
variances = []
noises = []
def tracing_loss_fn(model, guide):
    loss = loss_fn(model, guide)
    locations.append(sgpr.Xu.data.numpy().copy())
    variances.append(sgpr.kernel.variance.item())
    noises.append(sgpr.noise.item())
    return loss
gp.util.train(sgpr, optimizer, tracing_loss_fn)
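The workaround generalizes to any loss function: the wrapper closes over the lists and snapshots the quantities on every call. A pure-Python sketch of that pattern (params stands in for the model's attributes; the names here are illustrative, not Pyro API):

```python
# Generic version of the tracing-hook workaround: wrap a loss function so
# that side-effect lists record the chosen quantities on each call.

def make_tracing_loss(loss_fn, params, records):
    """Return a loss function that also appends params[name] to records[name]."""
    def traced(*args, **kwargs):
        loss = loss_fn(*args, **kwargs)
        for name, store in records.items():
            # Floats are immutable, so appending is already a snapshot;
            # for tensors you would copy, as in sgpr.Xu.data.numpy().copy().
            store.append(params[name])
        return loss
    return traced

params = {"lengthscale": 1.0}
records = {"lengthscale": []}
traced = make_tracing_loss(lambda x: x * x, params, records)

for x in (3.0, 2.0):
    traced(x)
    params["lengthscale"] *= 0.5  # parameter update between calls
```

After the loop, records["lengthscale"] holds [1.0, 0.5]: the value as it was at each call, which is exactly what the tracing_loss_fn above achieves for the GP hyperparameters.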