
Unable to pass `aux` data to `tfp.math.minimize_stateless`

Open · patel-zeel opened this issue 2 years ago · 1 comment

Hi,

I was trying to use tfp.math.minimize_stateless for one of my applications and found that it optimizes every argument passed to loss_fun, not just the model parameters. Here is a minimal working example (MWE):

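import jax.numpy as jnp
import optax
import tensorflow_probability.substrates.jax as tfp

def loss_fun(x, a):
  return (((x['param1'] - a) + (x['param2'] - (a+1)))**2).sum()

N = 3
init_params = lambda: {'param1': jnp.zeros(N), 'param2': jnp.ones(N)}
a = 2.0  # meant to be a fixed (non-trainable) constant

optimizer = optax.adam(learning_rate=0.1)
# `init` is the tuple (params, a): loss_fun is called as loss_fun(params, a),
# and every element of the tuple, including `a`, ends up being optimized.
params, losses = tfp.math.minimize_stateless(loss_fun, (init_params(), a), num_steps=1000, optimizer=optimizer)
print(params)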

Output

({'param1': DeviceArray([1.0000008, 1.0000008, 1.0000008], dtype=float32),
  'param2': DeviceArray([1.9999989, 1.9999989, 1.9999989], dtype=float32)},
 DeviceArray(0.9999999, dtype=float32))

The output shows that `a` is optimized as well (it moves from 2.0 to about 1.0), but I want to treat it as a fixed parameter. Is there a way to do this with tfp.math.minimize_stateless? In other words, how can I pass auxiliary data to the loss function?

I see that jax.grad takes an argnums argument to specify which arguments to differentiate with respect to, but (AFAIK) tfp.math.minimize_stateless has nothing similar.
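For comparison, here is a minimal plain-JAX sketch (not part of the TFP API) of what `argnums` does for `jax.grad` on the same loss: only the selected argument receives a gradient, so `a` is treated as a constant unless it is explicitly listed.

```python
import jax
import jax.numpy as jnp


def loss_fun(x, a):
    # Same loss as in the MWE above.
    return (((x["param1"] - a) + (x["param2"] - (a + 1))) ** 2).sum()


N = 3
params = {"param1": jnp.zeros(N), "param2": jnp.ones(N)}
a = 2.0

# argnums=0 (also the default): gradients only w.r.t. the params dict; `a` is held fixed.
grads = jax.grad(loss_fun, argnums=0)(params, a)

# argnums=(0, 1): a gradient is also returned for `a`, which mirrors what happens
# in the MWE when `a` is placed inside `init`.
grads_both = jax.grad(loss_fun, argnums=(0, 1))(params, a)

print(grads)
print(grads_both[1])
```

`grads` has the same dict structure as `params`, while `grads_both` is a tuple whose second entry is the (scalar) gradient for `a`.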

patel-zeel · Jun 10 '22

Can you do this instead, so that `a` is captured by the lambda rather than passed in `init`?

params, losses = tfp.math.minimize_stateless(
    lambda x: loss_fun(x, a),
    init_params(),
    num_steps=1000,
    optimizer=optimizer)
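The same closure idea can be checked without TFP. The sketch below binds `a` with `functools.partial` (equivalent to the lambda above) and, as an assumption for brevity, uses a hand-written plain gradient-descent step in place of `minimize_stateless` with Adam; `LEARNING_RATE`, `step` and the step count are illustrative choices, not TFP names.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fun(x, a):
    return (((x["param1"] - a) + (x["param2"] - (a + 1))) ** 2).sum()


N = 3
a = 2.0
init_params = {"param1": jnp.zeros(N), "param2": jnp.ones(N)}

# Bind `a` ahead of time; equivalent to `lambda x: loss_fun(x, a)`.
fixed_loss = functools.partial(loss_fun, a=a)

LEARNING_RATE = 0.1  # assumed step size for plain gradient descent


@jax.jit
def step(params):
    # One stateless update: differentiate only w.r.t. `params`; `a` stays constant.
    loss, grads = jax.value_and_grad(fixed_loss)(params)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


params = init_params
losses = []
for _ in range(100):
    params, loss = step(params)
    losses.append(float(loss))

print(params)
```

With `a` fixed at 2.0, the optimum only requires `param1 + param2 == 2*a + 1 == 5` elementwise; because both entries receive identical gradients, this loop moves them equally from (0, 1) to roughly (2, 3), and `a` is never touched.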

csuter · Jun 15 '22

Thanks for the suggestion, @csuter. I think that should solve it. I don't remember the context now since it's been a long time, but I'm closing the issue to make room for new ones :)

patel-zeel · Jan 14 '23