
[FR] gp.util.train to return a trace over multiple quantities rather than just loss

Open nipunbatra opened this issue 3 years ago • 1 comments

Currently, gp.util.train returns only a trace of the loss (loss vs. iterations):

losses = gp.util.train(vsgp, num_steps=num_steps)

If one needs a trace of other quantities, say kernel parameters, one has to write one's own training loop, like:


# Assumes sgpr (e.g. a gp.models.SparseGPRegression) and optimizer are already set up.
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
losses = []
locations = []
variances = []
lengthscales = []
noises = []
num_steps = 2000 if not smoke_test else 2
for i in range(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(sgpr.model, sgpr.guide)
    # Record the current value of each quantity of interest.
    locations.append(sgpr.Xu.data.numpy().copy())
    variances.append(sgpr.kernel.variance.item())
    lengthscales.append(sgpr.kernel.lengthscale.item())
    noises.append(sgpr.noise.item())
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

I propose adding a trace_quantities argument to gp.util.train.

As an example, tfp.math.minimize allows this

trace_fn = lambda traceable_quantities: {
    "loss": traceable_quantities.loss,
    "theta": theta,
}

num_steps = 150

trace = tfp.math.minimize(
    loss_fn=loss_fn,
    num_steps=num_steps,
    optimizer=tf.optimizers.Adam(learning_rate=0.01),
    trace_fn=trace_fn,
)

Then, we can obtain the traces as follows:

fig, ax = plt.subplots(nrows=2, sharex=True, figsize=(6, 4))
ax[0].plot(range(num_steps), trace["loss"])
ax[1].plot(range(num_steps), trace["theta"])

By default, the traced quantity could simply be the loss, preserving the current behavior.
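To make the proposal concrete, here is a minimal sketch of what a trace_fn hook on a training loop could look like. This is not Pyro's actual API: the train function, its signature, and the toy quadratic loss below are all hypothetical, with plain gradient descent on a Python float standing in for ELBO optimization.

```python
def train(loss_fn, params, lr=0.1, num_steps=100, trace_fn=None):
    """Minimize loss_fn by gradient descent, recording trace_fn each step.

    loss_fn(params) returns (loss, grads); trace_fn(loss, params) chooses
    what to record. If trace_fn is None, only the loss is traced, which
    matches the current behavior of returning losses.
    """
    trace = []
    for _ in range(num_steps):
        loss, grads = loss_fn(params)
        trace.append(loss if trace_fn is None else trace_fn(loss, params))
        for name, grad in grads.items():
            params[name] -= lr * grad
    return trace

# Toy example: minimize (theta - 3)^2, tracing both loss and theta.
def quadratic(params):
    theta = params["theta"]
    return (theta - 3.0) ** 2, {"theta": 2.0 * (theta - 3.0)}

params = {"theta": 0.0}
trace = train(
    quadratic,
    params,
    trace_fn=lambda loss, p: {"loss": loss, "theta": p["theta"]},
)
```

Each entry of trace is then a dict, so trace[i]["theta"] gives the parameter trajectory for plotting, analogous to the tfp.math.minimize example above.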

nipunbatra avatar Mar 03 '22 00:03 nipunbatra

Note that you can already do this by wrapping loss_fn in a tracing hook:

sgpr = ...
optimizer = ...
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

# Wrap loss_fn in a tracing hook.
locations = []
variances = []
noises = []
def tracing_loss_fn(model, guide):
    loss = loss_fn(model, guide)
    locations.append(sgpr.Xu.data.numpy().copy())
    variances.append(sgpr.kernel.variance.item())
    noises.append(sgpr.noise.item())
    return loss

gp.util.train(sgpr, optimizer, tracing_loss_fn)

fritzo avatar Mar 03 '22 14:03 fritzo