Unable to pass `aux` data to `tfp.math.minimize_stateless`
Hi,

I was trying to use `tfp.math.minimize_stateless` for one of my applications and found that it automatically optimizes all the arguments passed to the `loss_fun`. For example, here is a MWE:
```python
import jax.numpy as jnp
import optax
import tensorflow_probability.substrates.jax as tfp

def loss_fun(x, a):
    return (((x['param1'] - a) + (x['param2'] - (a + 1)))**2).sum()

N = 3
init_params = lambda: {'param1': jnp.zeros(N), 'param2': jnp.ones(N)}
a = 2.0
optimizer = optax.adam(learning_rate=0.1)
params, losses = tfp.math.minimize_stateless(
    loss_fun, (init_params(), a), num_steps=1000, optimizer=optimizer)
print(params)
```
Output:

```
({'param1': DeviceArray([1.0000008, 1.0000008, 1.0000008], dtype=float32),
  'param2': DeviceArray([1.9999989, 1.9999989, 1.9999989], dtype=float32)},
 DeviceArray(0.9999999, dtype=float32))
```
The output shows that `a` is optimized as well, but I want to treat it as a fixed parameter. Is there a way to do this with `tfp.math.minimize_stateless`? In other words, how can I pass auxiliary data to the loss function?
I see that `jax.grad` takes an `argnums` argument to specify which arguments to differentiate with respect to, but something similar is not present (AFAIK) in `tfp.math.minimize_stateless`.
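For reference, this is the `jax.grad` behavior being alluded to: `argnums` selects which positional arguments get gradients, and everything else is treated as fixed data. A minimal sketch with a toy quadratic loss (the names here are illustrative, not from the issue):

```python
import jax
import jax.numpy as jnp

def loss_fun(x, a):
    # Simple quadratic loss; `a` is auxiliary data.
    return jnp.sum((x - a) ** 2)

x = jnp.array([1.0, 2.0])
a = 3.0

# argnums=0 (the default): differentiate w.r.t. `x` only.
# The gradient of sum((x - a)^2) w.r.t. x is 2 * (x - a).
g = jax.grad(loss_fun, argnums=0)(x, a)
```

No gradient is ever computed for `a` here, which is exactly the fixed-parameter behavior the question asks for.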
Can you do

```python
params, losses = tfp.math.minimize_stateless(
    lambda x: loss_fun(x, a),
    init_params(),
    num_steps=1000,
    optimizer=optimizer)
```
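More generally, the trick is to bind the auxiliary data to the loss function before handing it to the optimizer, so the optimizer only ever sees the trainable argument. This works with a closure (as above) or with `functools.partial`; a minimal pure-Python sketch of the pattern, with a toy loss of my own:

```python
from functools import partial

def loss_fun(x, a):
    # Toy scalar loss with a fixed offset `a`.
    return (x - a) ** 2

a = 2.0

# Option 1: a closure that captures `a` from the enclosing scope.
closed = lambda x: loss_fun(x, a)

# Option 2: functools.partial binds `a` as a keyword argument.
bound = partial(loss_fun, a=a)

# Both single-argument functions evaluate identically:
print(closed(3.0))  # 1.0
print(bound(3.0))   # 1.0
```

Either single-argument function can then be passed as the `loss_fn` to `tfp.math.minimize_stateless`, and only `x` will be optimized.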
Thanks for the suggestion, @csuter. I think that should solve it. I don't remember the context now since it's been a long time, but I thought of closing the issue to make room for the new ones :)