VINF
Refactor flow model with tf.bijector
I tried on my own to improve the Keras model implementation of the flows, which in my opinion is less well suited to handling probabilistic layers than what the TensorFlow Probability package provides. That attempt failed, so I opened a Stack Overflow question to see if anyone had a solution to the problem I encountered. To do: check whether any of the answers at https://stackoverflow.com/questions/61717694/embed-trainable-bijector-into-keras-model/62284510#62284510 are valid.
Any progress on this? Or a suggestion where to start?
Hi @kaijennissen,
I did not come back to this issue since I created it, so unfortunately no progress has been made.
Happy to guide you. What's your motivation for wanting to use tf.bijector? :)
Hi,
I'm interested in probabilistic deep learning and would like to use normalizing flows to build more expressive posteriors for the weights in a neural network.
In my opinion tensorflow-probability already provides all the necessary pieces to build such a model (chain a few bijectors, use them to transform a base distribution and then learn the parameters of the bijector).
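To make the "transform a base distribution" step concrete, here is a minimal NumPy sketch of the change-of-variables rule for the simplest possible flow, an elementwise shift and scale applied to a standard normal. The function names and the example values are made up for illustration and are not part of tensorflow-probability; the sketch only checks that the flow density agrees with the closed-form normal density.

```python
import numpy as np


def standard_normal_log_prob(z):
    """Elementwise log-density of N(0, 1)."""
    return -0.5 * (z ** 2 + np.log(2.0 * np.pi))


def affine_flow_log_prob(x, shift, scale):
    """Log-density of x = shift + scale * z with z ~ N(0, I).

    Change of variables: log p(x) = log p_z(z) + log|det dz/dx|,
    and for an elementwise scale log|det dz/dx| = -sum(log(scale)).
    """
    z = (x - shift) / scale              # inverse transform
    log_det_inverse = -np.log(scale)     # per-dimension Jacobian term
    return np.sum(standard_normal_log_prob(z) + log_det_inverse, axis=-1)


def normal_log_prob(x, loc, scale):
    """Closed-form log-density of an independent Normal(loc, scale)."""
    z = (x - loc) / scale
    return np.sum(-0.5 * z ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi), axis=-1)


# Illustrative values only.
x = np.array([0.3, -1.2, 2.0])
shift = np.array([0.5, 0.0, 1.0])
scale = np.array([2.0, 0.5, 1.5])

flow_lp = affine_flow_log_prob(x, shift, scale)
ref_lp = normal_log_prob(x, shift, scale)
print(flow_lp, ref_lp)  # the two values should agree
```

Learning the bijector then amounts to treating shift and scale as trainable parameters and maximising this log-density (or an ELBO built on it).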
I've built a small toy example which works as expected (code below).
The next step would be to build a custom layer that could replace the DistributionLambda layer.
But when subclassing the DistributionLambda layer, the parameters of the bijector are not tracked (they are attributes of the class).
I thought you might have come across similar problems? Or do you think this is the wrong way and I should subclass tf.keras.layers.Layer?
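```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def posterior_trainable_bijector(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    # Offset chosen so that softplus(c) == 1 when the raw scale parameter is 0.
    c = np.log(np.expm1(1.0))
    return tf.keras.Sequential(
        [
            # 2 * n trainable values: the first n parameterise the scale,
            # the last n the shift.
            tfp.layers.VariableLayer(2 * n, dtype=dtype),
            tfp.layers.DistributionLambda(
                lambda t: tfd.TransformedDistribution(
                    # Base distribution: standard normal over n dimensions.
                    tfd.Independent(
                        tfd.Normal(loc=tf.zeros(n), scale=tf.ones(n)),
                        reinterpreted_batch_ndims=1,
                    ),
                    # Chain applies right to left: Scale first, then Shift.
                    tfp.bijectors.Chain(
                        bijectors=[
                            tfp.bijectors.Shift(t[..., n:]),
                            tfp.bijectors.Scale(
                                1e-5 + 0.01 * tf.math.softplus(c + t[..., :n])
                            ),
                        ]
                    ),
                )
            ),
        ]
    )
```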