
Run fit_with_hmc with jit_compile activated

Open williamjamir opened this issue 1 year ago • 0 comments

I'm trying to run the following code:

```python
import numpy as np
import tensorflow_probability as tfp
import tensorflow as tf

time_series_with_nans = [-1.0, 1.0, np.nan, 2.4, np.nan, 5]
observed_time_series = tfp.sts.MaskedTimeSeries(
    time_series=time_series_with_nans, is_missing=tf.math.is_nan(time_series_with_nans)
)

# Build model using observed time series to set heuristic priors.
linear_trend_model = tfp.sts.LocalLinearTrend(observed_time_series=observed_time_series)
model = tfp.sts.Sum([linear_trend_model], observed_time_series=observed_time_series)

# Fit model to data
parameter_samples, _ = tf.function(
    func=lambda ots: tfp.sts.fit_with_hmc(model, ots), jit_compile=True, autograph=False
)(observed_time_series)
```

Using JIT, as suggested in this comment (https://github.com/tensorflow/probability/issues/1704#issuecomment-1497773131), gives me the following error:

```
parameter_samples, _ = tf.function(
        func=lambda ots: tfp.sts.fit_with_hmc(model, ots),
        jit_compile=True,
        autograph=False)(observed_time_series)

test_jit_hmc.py:262: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.env/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py:153: in error_handler
    raise e.with_traceback(filtered_tb) from None
incrementality/prefect/flows/model_training.py:263: in <lambda>
    func=lambda ots: tfp.sts.fit_with_hmc(model, ots),
.env/lib/python3.11/site-packages/tensorflow_probability/python/sts/fitting.py:466: in fit_with_hmc
    variational_posterior = build_factored_surrogate_posterior(
.env/lib/python3.11/site-packages/tensorflow_probability/python/sts/fitting.py:173: in build_factored_surrogate_posterior
    return experimental_vi.build_factored_surrogate_posterior(
.env/lib/python3.11/site-packages/tensorflow_probability/python/internal/trainable_state_util.py:337: in build_stateful_trainable
    tf.nest.map_structure(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

t = <tf.Tensor 'fit_with_hmc/build_factored_surrogate_posterior/build_factored_surrogate_posterior/Normal_trainable_variables/normal/stateless_random_normal:0' shape=() dtype=float32>, n = 'loc'

>   lambda t, n=name: t if t is None else tf.Variable(t, name=n),
    value, expand_composites=True))
E   ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

.env/lib/python3.11/site-packages/tensorflow_probability/python/internal/trainable_state_util.py:338: ValueError
```
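The ValueError above reflects a general `tf.function` rule rather than anything specific to STS: variables may only be created on the first call, and a function that creates a new `tf.Variable` every time it is traced is rejected. The sketch below reproduces that rule without TFP, assuming TensorFlow 2.x with XLA available on CPU; the names `make_variable_inside`, `use_outer_variable`, `scale`, and `run_demo` are hypothetical. The second function shows the workaround described in the TensorFlow guide linked from the error message: create the variable once, outside the compiled function, and only read it inside.

```python
import tensorflow as tf


# Failing pattern (hypothetical example): a new tf.Variable is created every
# time the function body is traced. tf.function traces once to create the
# variables, then traces again with variable creation disallowed, so the very
# first call raises ValueError.
@tf.function
def make_variable_inside(x):
    v = tf.Variable(1.0)  # created inside the traced function
    return v * x


# Working pattern: the variable is created once, eagerly, outside the
# compiled function; the XLA-compiled function only reads it.
scale = tf.Variable(2.0, name="scale")


@tf.function(jit_compile=True)
def use_outer_variable(x):
    return scale * x


def run_demo():
    """Return (error message from the failing pattern, result of the working one)."""
    try:
        make_variable_inside(tf.constant(3.0))
        failure = None
    except ValueError as err:
        failure = str(err)
    result = float(use_outer_variable(tf.constant(3.0)))
    return failure, result


failure_message, outer_result = run_demo()
print("failing pattern:", failure_message)
print("working pattern:", outer_result)  # 2.0 * 3.0 = 6.0
```

In the traceback, the failing call is `tf.Variable(t, name=n)` inside `trainable_state_util.py`, so `fit_with_hmc` follows the first pattern when it is wrapped in `tf.function`.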

It looks like tfp.sts.fit_with_hmc creates tf.Variables as part of its execution (in the traceback, inside build_factored_surrogate_posterior), which raises the question:

  • How can I enable JIT for this method? Is this a limitation?

If so, given that STS doesn't work on GPU, and JAX doesn't work either (https://github.com/tensorflow/probability/issues/1646#issue-1434159113), are there any other ways to speed up fit_with_hmc?

Related issues:

  • https://github.com/tensorflow/probability/issues/1704
  • https://github.com/tensorflow/probability/issues/1395

I'm using Python 3.11 with:

```
pip list | grep tensorflow
tensorflow                         2.17.0
tensorflow-probability             0.24.0
```

Using TensorFlow 2.16 also produces the same error:

```
pip list | grep tensor
tensorboard                        2.16.2
tensorboard-data-server            0.7.2
tensorflow                         2.16.1
tensorflow-io-gcs-filesystem       0.36.0
tensorflow-probability             0.24.0
```

williamjamir, Aug 13 '24 09:08