TFP optimizers should take an "additional_args" argument which is passed through to value_and_gradients_function
Consider the following scenario (shamelessly copied from the TFP tutorial):
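```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

np.random.seed(12345)

# Problem setup (generated in the same order as in the full example below).
dim = 100
batches = 1000
minimum = np.random.randn(batches, dim)
scales = np.exp(np.random.randn(batches, dim))

@tf.function
def quadratic(minimum, x):
    # Returns the objective value and its gradient with respect to x.
    with tf.GradientTape() as g:
        g.watch(x)
        out = tf.reduce_sum(input_tensor=scales * (x - minimum)**2, axis=-1)
    grad = g.gradient(out, x)
    return out, grad
```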
We wish to minimize quadratic with respect to x for a specific value of minimum. However, tfp.optimizer does not allow passing any additional arguments "straight through" to the value_and_gradients_function.
This is a problem because it forces eager execution, as follows:
```python
def quadratic_wrapper(minimum):
    # Returns a Python closure with `minimum` bound.
    return (lambda x: quadratic(minimum, x))
```
Now quadratic_wrapper cannot be turned into a tf.function, as it returns a closure rather than a tensor. So we must do something like this:
```python
tfp.optimizer.lbfgs_minimize(quadratic_wrapper(minimum), ...)
```
The above also cannot be turned into a tf.function, as it takes a Python function as an argument.
Basically, we're forced into eager execution. If TFP allowed arguments to be passed through to value_and_gradients_function, we could avoid this, because "minimum" in the above example can be represented as a tensor.
I've hacked together a fix for this and noticed a 2x speedup in my code from implementing something like this. It would be nice if this were officially supported.
I think we'd normally do this to work around the python function argument:
```python
fn = tf.function(lambda: tfp.optimizer.lbfgs_minimize(quadratic_wrapper(minimum), ...))
...
fn()  # <--- compiles and executes the lambda
```
(I'd usually also include autograph=False and jit_compile=True in the tf.function call)
Hi csuter,
Thanks for the quick reply. To be more specific, I'd like to run tfp.optimizer.lbfgs_minimize() several times with several different "minimum" arguments. The problem is that this would require a compilation per execution of fn() above. The 1-2 second cost of compilation is too expensive for me in this case, as I'm calling this minimize function several times a second.
Hopefully this helps clarify the issue?
Can you plumb minimum through from the outer lambda?
On Fri, Apr 22, 2022 at 13:58 Mohit Rajpal wrote:
Unfortunately not without causing retracing every time the outer lambda is called.
The problem is that the TFP optimizer cannot "see through" the inner lambda (return (lambda x: quadratic(minimum, x))) to recognize that it merely curries another tf.function.
This is perhaps demonstrated with the following code:
```python
import datetime

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

np.random.seed(12345)

dim = 100
batches = 1000
minimum = np.random.randn(batches, dim)
scales = np.exp(np.random.randn(batches, dim))

@tf.function
def quadratic(minimum, x):
    with tf.GradientTape() as g:
        g.watch(x)
        out = tf.reduce_sum(input_tensor=scales * (x - minimum)**2, axis=-1)
    grad = g.gradient(out, x)
    return out, grad

def quadratic_wrapper(minimum):
    # This creates a Python closure
    return (lambda x: quadratic(minimum, x))

start = tf.ones((batches, dim), dtype='float64')

@tf.function
def optimize(minimum):
    # Forced retrace every time, as quadratic_wrapper returns a new
    # Python object every time
    return tfp.optimizer.lbfgs_minimize(
        quadratic_wrapper(minimum), initial_position=start,
        stopping_condition=tfp.optimizer.converged_all,
        max_iterations=100,
        tolerance=1e-8)

print(datetime.datetime.now())
for i in range(10):
    optimize(minimum)
print(datetime.datetime.now())

@tf.function
def quadratic(x):
    # Traces once; the global `minimum` is captured on the first trace
    with tf.GradientTape() as g:
        g.watch(x)
        out = tf.reduce_sum(input_tensor=scales * (x - minimum)**2, axis=-1)
    grad = g.gradient(out, x)
    return out, grad

@tf.function
def optimize():
    # Traces once. The one-argument `quadratic` is passed directly: after the
    # redefinition above, quadratic_wrapper would call quadratic(minimum, x)
    # with two arguments and fail.
    return tfp.optimizer.lbfgs_minimize(
        quadratic, initial_position=start,
        stopping_condition=tfp.optimizer.converged_all,
        max_iterations=100,
        tolerance=1e-8)

print(datetime.datetime.now())
for i in range(10):
    optimize()
print(datetime.datetime.now())
```
In the above, the first loop of optimize(minimum) calls takes several times longer than the second loop of optimize() calls.
As far as I know, there's no way to do partial application in Python without creating a new Python object (currying), which forces retracing.
Sorry I haven't had more time to engage on this. I don't think retracing is necessarily the issue though. If you put a print statement inside your optimize functions you can see how many times they are traced (the print will only happen on the tracing pass).
If things are still too slow, can you try @tf.function(autograph=False, jit_compile=True) and see if there's an improvement on the post-compilation run times?