Error initializing tfb.AutoregressiveNetwork using jax substrate
Trying to initialize an instance of tfb.AutoregressiveNetwork using the jax substrate fails with an AttributeError.
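With the example usage from the docs (https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/AutoregressiveNetwork):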
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfb.AutoregressiveNetwork(params=2, hidden_units=[10, 10])
raises the error:
File .../site-packages/tensorflow_probability/substrates/jax/bijectors/masked_autoregressive.py:967, in AutoregressiveNetwork.__init__(self, params, event_shape, conditional, conditional_event_shape, conditional_input_layers, hidden_units, input_order, hidden_degrees, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, kernel_constraint, bias_constraint, validate_args, **kwargs)
965 self._kernel_regularizer = kernel_regularizer
966 self._bias_regularizer = bias_regularizer
--> 967 self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
968 self._bias_constraint = bias_constraint
969 self._validate_args = validate_args
AttributeError: module 'tensorflow_probability.python.internal.backend.jax.numpy_keras' has no attribute 'constraints'
Tested with TFP 0.19.0 and JAX versions 0.4.4 and 0.3.25.
Bijector code that uses Keras from TF is not going to work well in JAX (other examples would be Glow and PixelCNN). You can look at masked_autoregressive_test.py to see which tests are disabled with JAX; it's many of them.
But I think you could probably use the bijector as part of a Flax module, possibly even as the return value of __call__.
If you wanted to contribute some kind of stateless AutoregressiveNetwork for JAX, I think it would make a nice PR.
Brian Patton | Software Engineer
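A minimal sketch of what the stateless AutoregressiveNetwork suggested above could look like in pure JAX, assuming a MADE-style design (masked dense layers with a sequential input ordering and ReLU hidden layers). The names make_made, init_fn and apply_fn are hypothetical and not part of TFP. Parameters are an explicit pytree rather than Keras state, so they could be created inside a Flax module or updated directly with jax.grad. The output has shape [..., event_size, params], mirroring AutoregressiveNetwork(params=2, ...).

```python
import jax
import jax.numpy as jnp


def _made_masks(event_size, hidden_units):
    """Builds MADE connectivity masks for a sequential input ordering.

    Input j has degree j + 1. Hidden units get degrees cycling through
    1 .. event_size - 1 and may only see units of lower or equal degree.
    Outputs for position i may only see units of strictly lower degree, so
    output i depends on inputs 0 .. i - 1 only.
    """
    input_degrees = jnp.arange(1, event_size + 1)
    degrees = [input_degrees]
    for units in hidden_units:
        degrees.append(jnp.arange(units) % max(event_size - 1, 1) + 1)
    masks = [
        (d_in[:, None] <= d_out[None, :]).astype(jnp.float32)
        for d_in, d_out in zip(degrees[:-1], degrees[1:])
    ]
    masks.append((degrees[-1][:, None] < input_degrees[None, :]).astype(jnp.float32))
    return masks


def make_made(event_size, params, hidden_units, activation=jax.nn.relu):
    """Returns (init_fn, apply_fn) for a stateless masked autoregressive MLP."""
    masks = _made_masks(event_size, hidden_units)
    sizes = [event_size, *hidden_units]

    def init_fn(key):
        keys = jax.random.split(key, len(sizes))
        layers = []
        for k, n_in, n_out in zip(keys[:-1], sizes[:-1], sizes[1:]):
            w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
            layers.append((w, jnp.zeros(n_out)))
        # Output layer: one weight slice per (event position, param).
        w_out = 0.01 * jax.random.normal(keys[-1], (sizes[-1], event_size, params))
        layers.append((w_out, jnp.zeros((event_size, params))))
        return layers

    def apply_fn(layers, x):
        h = x
        for (w, b), m in zip(layers[:-1], masks[:-1]):
            h = activation(h @ (w * m) + b)  # masking enforces the ordering
        w_out, b_out = layers[-1]
        m_out = masks[-1][:, :, None]  # broadcast the mask over the params axis
        return jnp.einsum("...k,kep->...ep", h, w_out * m_out) + b_out

    return init_fn, apply_fn


# Usage: a batch of 3 events of size 4, two outputs (e.g. shift, log_scale) each.
init_fn, apply_fn = make_made(event_size=4, params=2, hidden_units=[16, 16])
layers = init_fn(jax.random.PRNGKey(0))
print(apply_fn(layers, jnp.ones((3, 4))).shape)  # (3, 4, 2)
```

Going by the documented behavior of tfb.MaskedAutoregressiveFlow (a shift_and_log_scale_fn that returns a single tensor is unstacked along its last axis into shift and log_scale), functools.partial(apply_fn, layers) with params=2 could plausibly be passed as shift_and_log_scale_fn; that combination has not been checked against the JAX substrate here.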
(Emailed reply to GitHub issue https://github.com/tensorflow/probability/issues/1699, opened by Giles Harper-Donnelly on Tue, Feb 28, 2023 at 9:12 AM.)