[FEATURE] Use numpyro

Open sash-a opened this issue 1 year ago • 1 comments

Feature

Switch distribution libraries...again!

The problem with tfp is that if we want to run non-shared parameters, we need to vmap the apply function (over params and observations). This means it would return a jax array of tfp distributions, and since a tfp distribution is not a jax type, this cannot work. But numpyro's distribution objects are jax types and are vmappable! So we can just use them as a drop-in replacement. Here's a proof of concept:

import jax.numpy as jnp
import jax
import flax.linen as nn
import numpyro

class Network(nn.Module):
    @nn.compact
    def __call__(self, x):
        return numpyro.distributions.Categorical(logits=nn.Dense(5)(x))


n_agents = 4

key = jax.random.PRNGKey(3)
keys = jax.random.split(key, n_agents)

x = jnp.arange(5, dtype=float)
xs = x[jnp.newaxis].repeat(n_agents, axis=0)

net = Network()
params = jax.vmap(net.init)(keys, xs)

dist = jax.jit(jax.vmap(net.apply))(params, xs)
action = dist.sample(key, (n_agents,))  # Array([2, 3, 3, 2], dtype=int32)
dist.log_prob(action)  # Array([-1.9792106 , -1.3051271 , -0.10164165, -2.1678243 ], dtype=float32)
dist.entropy()  # Array([0.52041006, 1.3267    , 0.37860775, 0.91515315], dtype=float32)

Replacing numpyro.distributions.Categorical with tfp's Categorical fails, because tfp distributions are objects which are not jax types. The error is:

ValueError: Attempt to convert a value (<object object at 0x7fa1a191bfa0>) with an unsupported type (<class 'object'>) to a Tensor.
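
The point about "jax types" can be illustrated without either library. The sketch below is only an illustration of the general mechanism, not numpyro's or tfp's actual implementation, and it does not reproduce the tfp error above. `ToyCategorical` and `OpaqueCategorical` are hypothetical classes: the first is registered as a pytree, so `jax.vmap` can return it with batched leaves; the second is a plain Python object and is rejected as a vmap output.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyCategorical:
    """Hypothetical stand-in for a pytree-registered distribution."""

    def __init__(self, logits):
        # No validation or conversion here: JAX may rebuild the object
        # with placeholder leaves while handling the tree structure.
        self.logits = logits

    def log_prob(self, value):
        # Normalise the logits, then pick the entry for each action.
        log_p = jax.nn.log_softmax(self.logits, axis=-1)
        return jnp.take_along_axis(log_p, value[..., None], axis=-1)[..., 0]

    def tree_flatten(self):
        return (self.logits,), None  # (array leaves, static aux data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


class OpaqueCategorical:
    """Same data, but unregistered, so JAX treats it as an opaque object."""

    def __init__(self, logits):
        self.logits = logits


n_agents = 4
xs = jnp.zeros((n_agents, 5))  # uniform logits for every agent

# The registered class can be returned from vmap; its leaf comes back batched.
batched = jax.vmap(lambda x: ToyCategorical(x))(xs)
batched_shape = batched.logits.shape
log_probs = batched.log_prob(jnp.zeros(n_agents, dtype=jnp.int32))

# The unregistered class cannot be returned from vmap.
try:
    jax.vmap(lambda x: OpaqueCategorical(x))(xs)
    opaque_failed = False
except TypeError:
    opaque_failed = True

print(batched_shape, log_probs.shape, opaque_failed)
```

With uniform logits, every entry of `log_probs` should equal `-log(5)`, which is an easy sanity check that the batched object still behaves like a distribution per agent.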

Aug 15 '24 14:08 sash-a

Some additional testing shows that tfp and numpyro produce the same outputs. This needs more investigation to confirm 100%, e.g. for larger sample shapes and through backprop:

import jax
import jax.numpy as jnp
import numpyro.distributions as npd
from numpyro.distributions.transforms import ParameterFreeTransform
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd


class TanhTransform(ParameterFreeTransform):
    codomain = npd.constraints.open_interval(-1, 1)
    sign = 1

    def __call__(self, x):
        return jnp.tanh(x)

    def _inverse(self, y):
        return jnp.atanh(y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        #  This formula is mathematically equivalent to
        #  `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more numerically
        #  stable.
        #  Derivation:
        #    log(1 - tanh(x)^2)
        #    = log(sech(x)^2)
        #    = 2 * log(sech(x))
        #    = 2 * log(2e^-x / (e^-2x + 1))
        #    = 2 * (log(2) - x - log(e^-2x + 1))
        #    = 2 * (log(2) - x - softplus(-2x))
        return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))


loc = jnp.array([0, 1, 2], dtype=float)
scale = jnp.array([3, 4, 5], dtype=float)
tfp_norm = tfd.Normal(loc=loc, scale=scale)
tfp_tanh = tfb.Tanh()

npr_norm = npd.Normal(loc=loc, scale=scale)
npr_tanh = TanhTransform()

print(tfp_tanh(tfp_norm.sample(seed=jax.random.key(1))))
print(npr_tanh(npr_norm.sample(key=jax.random.key(1))))
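
The derivation in the `log_abs_det_jacobian` comment can also be checked numerically. The following is a standalone sketch, not part of the proposed change: it uses only Python's `math` module (float64) so that it runs without JAX, and the helper names `softplus`, `log_det_naive` and `log_det_stable` are made up for illustration. It compares the direct form `log(1 - tanh(x)^2)` with the rearranged form `2 * (log(2) - x - softplus(-2x))`.

```python
import math


def softplus(z: float) -> float:
    # Numerically stable log(1 + e^z).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def log_det_naive(x: float) -> float:
    # Direct form: log(1 - tanh(x)^2).
    return math.log1p(-math.tanh(x) ** 2)


def log_det_stable(x: float) -> float:
    # Rearranged form from the derivation: 2 * (log(2) - x - softplus(-2x)).
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


# For moderate inputs the two forms agree to floating-point precision.
moderate = (-3.0, -0.5, 0.0, 0.5, 3.0)
agree = all(
    math.isclose(log_det_naive(x), log_det_stable(x), rel_tol=1e-9, abs_tol=1e-12)
    for x in moderate
)
print(agree)

# For large |x|, tanh(x) rounds to exactly 1.0, so the direct form hits log(0).
x_big = 25.0
try:
    naive_big = log_det_naive(x_big)
except ValueError:  # math.log1p(-1.0) raises a domain error
    naive_big = float("-inf")
stable_big = log_det_stable(x_big)
print(naive_big, stable_big)
```

In float32, which is JAX's default, `tanh(x)` saturates at smaller `|x|` than in float64, so the direct form would break down earlier than in this check.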

Oct 21 '24 09:10 sash-a