
tf.vectorized_map not compatible with LinearGaussianStateSpaceModel forward_filter?

Open Qiustander opened this issue 2 years ago • 3 comments

Hi all, I have a batch of data and want to run a Kalman filter on each observation set. My data has shape observations = (batch, num_time_lens, feature_dim), so I use tf.vectorized_map for parallel computation, with TFP's official Kalman filter implementation.

However, tf.vectorized_map works fine with the tf.function-wrapped version of the Kalman filter, but fails when XLA compilation is used. A reproducible example (adapted from the TFP docstrings):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
ndims = 2
step_std = 1.0
noise_std = 5.0
model = tfd.LinearGaussianStateSpaceModel(
    num_timesteps=100,
    transition_matrix=tf.linalg.LinearOperatorIdentity(ndims),
    transition_noise=tfd.MultivariateNormalDiag(
        scale_diag=step_std**2 * tf.ones([ndims])),
    observation_matrix=tf.linalg.LinearOperatorIdentity(ndims),
    observation_noise=tfd.MultivariateNormalDiag(
        scale_diag=noise_std**2 * tf.ones([ndims])),
    initial_state_prior=tfd.MultivariateNormalDiag(
        scale_diag=tf.ones([ndims])))

"""
Generate data 
"""
x = model.sample(10) # Sample from the prior on sequences of observations.

def kalman_filter_wrapper(observations):
    # forward_filter returns (log_likelihoods, filtered_means, filtered_covs,
    # predicted_means, predicted_covs, observation_means, observation_covs).
    _, filtered_means, filtered_covs, _, _, _, _ = model.forward_filter(observations)
    return filtered_means

@tf.function(jit_compile=True)
def run_sim():
    means = tf.vectorized_map(kalman_filter_wrapper, x)
    return means

d = run_sim()

The error is:

2023-11-06 09:30:27.757506: W tensorflow/core/framework/op_kernel.cc:1828] OP_REQUIRES failed at xla_ops.cc:503 : INVALID_ARGUMENT: Detected unsupported operations when trying to compile graph __inference_run_sim_14763[_XlaMustCompile=true,config_proto=3175580994766145631,executor_type=11160318154034397263] on XLA_CPU_JIT: TensorListReserve (No registered 'TensorListReserve' OpKernel for XLA_CPU_JIT devices compatible with node {{function_node __inference_while_fn_14694}}{{node while_init/TensorArrayV2_11}}
	 (OpKernel was found, but attributes didn't match) Requested Attributes: element_dtype=DT_VARIANT, shape_type=DT_INT32){{function_node __inference_while_fn_14694}}{{node while_init/TensorArrayV2_11}}

What is the TensorListReserve operation? Is there a workaround? Thanks!
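For context, the per-sequence computation I want to vectorize is an ordinary Kalman filter. For this identity-dynamics model it can be sketched in plain NumPy (a hypothetical `batched_kalman_identity` helper, not TFP code; `q`, `r`, `p0` are the process, observation, and prior variances):

```python
import numpy as np

def batched_kalman_identity(obs, q, r, p0):
    """Kalman filter for identity transition/observation matrices with
    isotropic noise. obs has shape (batch, T, d). Returns the filtered
    means (batch, T, d) and the (data-independent) variances (T,)."""
    batch, T, d = obs.shape
    means = np.zeros_like(obs)
    variances = np.zeros(T)
    m = np.zeros((batch, d))  # zero-mean prior
    p = p0
    for t in range(T):
        p_pred = p + q                 # predict step
        k = p_pred / (p_pred + r)      # Kalman gain (scalar, shared across dims)
        m = m + k * (obs[:, t] - m)    # update mean with the innovation
        p = (1.0 - k) * p_pred         # update variance
        means[:, t] = m
        variances[t] = p
    return means, variances
```

With near-zero observation noise the gain approaches 1 and the filtered means track the observations, which is a convenient sanity check.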

Qiustander avatar Nov 06 '23 01:11 Qiustander

Does running model.forward_filter(x) not work? Most TFP distributions are built to natively vectorize and broadcast across parameters.

brianwa84 avatar Nov 06 '23 17:11 brianwa84

Does running model.forward_filter(x) not work? Most TFP distributions are built to natively vectorize and broadcast across parameters.

Hi, thanks for your reply. It did work for batched input. But I am curious about the TensorListReserve operation, which does not appear anywhere in the source code. Could you explain it? Thanks!

Qiustander avatar Nov 08 '23 05:11 Qiustander

TensorList* ops are how TF tracks gradients and accumulated state for while loops. But their representation differs inside XLA from outside, so they generally can't cross the jit_compile=True boundary. At least, that's where we've mostly seen similar exceptions in the past.
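As a rough illustration (my own toy sketch, not code from TFP): any tf.function loop that accumulates per-step results uses a TensorArray, which in graph mode is lowered to TensorList* ops such as TensorListReserve, even though those names never appear in user code:

```python
import tensorflow as tf

@tf.function
def cumsum_loop(x):
    # AutoGraph converts this Python loop into a tf.while_loop; the
    # TensorArray accumulator is lowered to TensorList* ops in the graph.
    n = tf.shape(x)[0]
    ta = tf.TensorArray(tf.float32, size=n)
    total = tf.constant(0.0)
    for i in tf.range(n):
        total += x[i]
        ta = ta.write(i, total)
    return ta.stack()

print(cumsum_loop(tf.constant([1., 2., 3.])))  # values [1. 3. 6.]
```

forward_filter builds exactly this kind of loop over timesteps internally, which is why the op shows up in the error.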

brianwa84 avatar Nov 08 '23 15:11 brianwa84