Log Uniform distribution

Open andrewfowlie opened this issue 2 years ago • 1 comment

How about a Log Uniform distribution, i.e., one where the log of the variable is uniformly distributed? Implementation:

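import jax
import jax.numpy as jnp
import numpyro.distributions as dist

class LogUniform(dist.Uniform):
    # Only sampling is log-uniform here: log_prob is still inherited from dist.Uniform.
    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        sample = jax.random.uniform(key, shape=shape, minval=jnp.log(self.low), maxval=jnp.log(self.high))
        return jnp.exp(sample)  # exp of a Uniform(log(low), log(high)) draw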
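The density produced by this sampler follows from a change of variables. Writing $a$ for `low` and $b$ for `high`, with $0 < a < b$, and $Y = \log X \sim \mathrm{Uniform}(\log a, \log b)$:

$$
p_X(x) = p_Y(\log x)\,\left|\frac{d}{dx}\log x\right| = \frac{1}{\log b - \log a}\cdot\frac{1}{x}, \qquad a \le x \le b .
$$

By contrast, the `log_prob` inherited from `dist.Uniform` returns the constant $-\log(b - a)$, so a subclass that only overrides `sample` would score values under the wrong density.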

A more polished version would check that low and high are both > 0. This is simple enough for users to write themselves, but having it built in would be convenient.

andrewfowlie, Apr 28 '22 02:04

Please feel free to submit a PR. I think we can also use:

# exp of Uniform(log(low), log(high)) is log-uniform on [low, high]
d = dist.TransformedDistribution(dist.Uniform(jnp.log(low), jnp.log(high)), dist.transforms.ExpTransform())

fehiepsi, Apr 28 '22 03:04

This can be closed by #1423.

yayami3, Dec 17 '22 00:12

Thanks @yayami3!

fehiepsi, Dec 17 '22 00:12