[FEATURE] Use numpyro
Feature
Switch distribution libraries...again!
The problem with tfp is that if we want to run non-shared parameters we need to `vmap` the apply function (over params and observations), but this means it would return a JAX array of tfp distributions, and since a tfp distribution is not a JAX type this cannot work. numpyro's distribution objects, however, are registered as JAX pytrees and are therefore vmappable, so we can use them as a drop-in replacement. Here's a proof of concept:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpyro


class Network(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Return a distribution object directly from the module.
        return numpyro.distributions.Categorical(logits=nn.Dense(5)(x))


n_agents = 4
key = jax.random.PRNGKey(3)
keys = jax.random.split(key, n_agents)

x = jnp.arange(5, dtype=float)
xs = x[jnp.newaxis].repeat(n_agents, axis=0)

net = Network()
# One set of (non-shared) parameters per agent.
params = jax.vmap(net.init)(keys, xs)
# The vmapped apply returns a single Categorical with batched logits.
dist = jax.jit(jax.vmap(net.apply))(params, xs)

action = dist.sample(key, (n_agents,))  # Array([2, 3, 3, 2], dtype=int32)
dist.log_prob(action)  # Array([-1.9792106 , -1.3051271 , -0.10164165, -2.1678243 ], dtype=float32)
dist.entropy()  # Array([0.52041006, 1.3267    , 0.37860775, 0.91515315], dtype=float32)
```
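Because the vmapped apply returns an ordinary numpyro distribution with batched parameters, it should also compose with `jax.grad`. A minimal sketch of what a policy-gradient-style loss could look like (`loss_fn` and `advantages` are illustrative, not part of the POC above):

```python
def loss_fn(params, xs, actions, advantages):
    # Re-build the batched distribution inside the loss so gradients
    # flow back into the per-agent parameters.
    dist = jax.vmap(net.apply)(params, xs)
    return -(dist.log_prob(actions) * advantages).mean()

advantages = jnp.ones(n_agents)  # dummy advantages, just to show backprop works
grads = jax.grad(loss_fn)(params, xs, action, advantages)
```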
Replacing the `numpyro.distributions.Categorical` with a tfp `Categorical` gives the following error, because tfp distributions are plain Python objects rather than JAX pytree types:

```
ValueError: Attempt to convert a value (<object object at 0x7fa1a191bfa0>) with an unsupported type (<class 'object'>) to a Tensor.
```
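For context, "JAX type" here means a registered pytree: recent numpyro versions register each distribution as a pytree that flattens into its parameter arrays, which is exactly what lets `vmap` and `jit` carry distribution objects across their boundaries. A quick sanity check (not from the POC above):

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as npd

# Flattening a numpyro distribution yields its parameter arrays,
# so JAX transformations can treat it like any other container.
d = npd.Categorical(logits=jnp.zeros(5))
print(jax.tree_util.tree_leaves(d))  # [Array([0., 0., 0., 0., 0.], dtype=float32)]
```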
Some additional testing shows that tfp and numpyro produce the same outputs, though this needs more investigation to confirm fully, e.g. for larger sample shapes and through backprop:
```python
import jax
import jax.numpy as jnp
import numpyro.distributions as npd
from numpyro.distributions.transforms import ParameterFreeTransform
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd


# Port of tfp's Tanh bijector to a numpyro transform.
class TanhTransform(ParameterFreeTransform):
    codomain = npd.constraints.open_interval(-1, 1)
    sign = 1

    def __call__(self, x):
        return jnp.tanh(x)

    def _inverse(self, y):
        return jnp.arctanh(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # This formula is mathematically equivalent to
        # `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more
        # numerically stable.
        # Derivation:
        #   log(1 - tanh(x)^2)
        #   = log(sech(x)^2)
        #   = 2 * log(sech(x))
        #   = 2 * log(2e^-x / (e^-2x + 1))
        #   = 2 * (log(2) - x - log(e^-2x + 1))
        #   = 2 * (log(2) - x - softplus(-2x))
        return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


loc = jnp.array([0, 1, 2], dtype=float)
scale = jnp.array([3, 4, 5], dtype=float)

tfp_norm = tfd.Normal(loc=loc, scale=scale)
tfp_tanh = tfb.Tanh()
npr_norm = npd.Normal(loc=loc, scale=scale)
npr_tanh = TanhTransform()

# Both should print the same tanh-squashed samples for the same key.
print(tfp_tanh(tfp_norm.sample(seed=jax.random.key(1))))
print(npr_tanh(npr_norm.sample(key=jax.random.key(1))))
```
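A sketch of how the equivalence could be checked through `log_prob` and backprop, continuing from the snippet above (`npr_logp`/`tfp_logp` are hypothetical helpers, and `y` is an arbitrary batch of points inside (-1, 1)):

```python
y = jnp.array([-0.5, 0.0, 0.9])  # arbitrary points in the codomain (-1, 1)

def npr_logp(loc):
    # numpyro's TransformedDistribution takes a transform (or a list of them).
    return npd.TransformedDistribution(npd.Normal(loc, scale), TanhTransform()).log_prob(y).sum()

def tfp_logp(loc):
    # tfp's TransformedDistribution takes a bijector.
    return tfd.TransformedDistribution(tfd.Normal(loc, scale), tfb.Tanh()).log_prob(y).sum()

print(jnp.allclose(npr_logp(loc), tfp_logp(loc)))  # values agree?
print(jnp.allclose(jax.grad(npr_logp)(loc), jax.grad(tfp_logp)(loc)))  # gradients agree?
```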