numpyro
Log Uniform distribution
How about adding a Log Uniform distribution, i.e., one where the log of the variable is uniformly distributed? Implementation:
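```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

class LogUniform(dist.Uniform):
    # Samples uniformly in log space; note log_prob is still Uniform's (not log-uniform).
    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape
        sample = jax.random.uniform(key, shape=shape, minval=jnp.log(self.low), maxval=jnp.log(self.high))
        return jnp.exp(sample)
```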
A more polished version would check that low and high are both > 0. This is simple for users to write themselves, but having it built in would be quite convenient.
Please feel free to submit a PR. I think we can also use:
```python
d = dist.TransformedDistribution(dist.Uniform(jnp.log(low), jnp.log(high)), dist.transforms.ExpTransform())
```
This can be closed by #1423.
Thanks @yayami3!