
Rename `optax.perturbations` distribution method `log_prob` to `unnormalized_log_prob`

carlosgmartin opened this issue on Mar 01 '25 · 7 comments

Each distribution under optax.perturbations has a method called log_prob that returns the logarithm of the unnormalized probability density at a given point.

I propose renaming this method to unnormalized_log_prob, like in TensorFlow Probability, so users don't confuse it with the actual (normalized) log probability.
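For illustration, here is roughly what this would look like for a user (constructor and sample signatures sketched from memory; check the current optax source for exact details):

```python
import jax
import optax

dist = optax.perturbations.Normal()
z = dist.sample(jax.random.PRNGKey(0), (3,))

# Today: despite the name, this can be an unnormalized log-density
# (e.g. dropping a constant such as -0.5 * log(2 * pi) for the normal).
lp = dist.log_prob(z)

# Proposed, matching TensorFlow Probability's naming:
# lp = dist.unnormalized_log_prob(z)
```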

If this is acceptable, I can submit a PR for it.

carlosgmartin avatar Mar 01 '25 20:03 carlosgmartin

I think the most important point is that the function make_perturbed_fun does not require normalized log probabilities, hence the advantage of highlighting that with the unnormalized_log_prob name. This can at least be highlighted in the docs of make_perturbed_fun.

The main issue is that there is no common class for "noises" like Gumbel or Normal. I believe the underlying idea (pinging the author of the code, @q-berthet) was that make_perturbed_fun could be used with standard sampling libraries like distrax, whose distributions share similar attributes (sample, log_prob).
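For context, this is roughly how the current class-based setup is used (a sketch; the argument order of the returned function is assumed key-first here and may differ across optax versions):

```python
import jax
import jax.numpy as jnp
import optax

def argmax_one_hot(x):
  # piecewise-constant function: zero gradient almost everywhere
  return jax.nn.one_hot(jnp.argmax(x), x.shape[-1])

# Any object exposing `sample` and `log_prob` (possibly unnormalized) fits
# here, which is what makes distrax distributions drop-in candidates.
pert_argmax = optax.perturbations.make_perturbed_fun(
    argmax_one_hot, num_samples=1000, sigma=0.5,
    noise=optax.perturbations.Gumbel())

y = pert_argmax(jax.random.PRNGKey(0), jnp.array([1.0, 2.0, 0.5]))
```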

I think it would be best to consolidate this situation either

  • by asking for two arguments (noise_sampler, noise_unnormalized_log_prob) in the signature of make_perturbed_fun (sketched below),
  • or by making a proper Noise class (which could then have sample, log_prob, etc.).

I would be more in favor of the first solution, which will require deprecating some existing behavior. Ideally we would not want to re-implement a library of distributions like distrax, so making our own class may clash with well-thought-out existing libraries.
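A rough sketch of what the first option could look like (all names here are hypothetical, not the actual optax API):

```python
import jax
import jax.numpy as jnp

# Hypothetical signature for the first option.
def make_perturbed_fun(
    fun,
    num_samples=1000,
    sigma=0.1,
    noise_sampler=jax.random.gumbel,  # (key, shape) -> samples
    noise_unnormalized_log_prob=lambda z: -z - jnp.exp(-z),  # Gumbel log-density
):
  ...
```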

Thanks for looking into that

vroulet avatar Mar 11 '25 21:03 vroulet

Sorry, I had not seen #1213. So before considering merging #1213, we should create a proper class, or simply make these distributions simple NamedTuples with sample and unnormalized_log_prob as attributes.
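For instance, something along these lines (a sketch only; the Noise container is hypothetical):

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

class Noise(NamedTuple):  # hypothetical container, not an existing optax class
  sample: Callable                 # (key, shape) -> samples
  unnormalized_log_prob: Callable  # samples -> log-density up to a constant

gumbel = Noise(
    sample=jax.random.gumbel,
    unnormalized_log_prob=lambda z: -z - jnp.exp(-z),
)
```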

vroulet avatar Mar 11 '25 22:03 vroulet

You raise an excellent point. I actually had the same thought: It would be better to make the interface function-based rather than object-based, since functions are more "nimble" (and let us easily reuse existing JAX code like jax.random.normal and jax.scipy.stats.norm.logpdf).

In that case, I'd deprecate the distribution classes, deprecate the noise parameter, and add a noise_fn parameter that outputs a sample and its unnormalized log-probability (for some generative processes, it can be more efficient to compute them jointly rather than separately).
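For concreteness, a sketch of the noise_fn contract being proposed (the name and signature are illustrative only):

```python
import jax
import jax.scipy.stats as stats

# Hypothetical noise_fn: returns a sample together with its (possibly
# unnormalized) log-probability, computed in a single pass.
def normal_noise_fn(key, shape):
  z = jax.random.normal(key, shape)
  return z, stats.norm.logpdf(z).sum()
```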

If that sounds good, I can submit a PR for it.

carlosgmartin avatar Mar 12 '25 03:03 carlosgmartin

Hello @carlosgmartin,

And thanks @vroulet for the ping.

As said above, the main idea was to be compatible with distrax, which typically exposes a sample and a log_prob function for each noise distribution, in a class. This compatibility used to be tested in jaxopt, but is no longer tested in optax.

I don't have a strong opinion regarding the best way to handle this, but I find the current setup practical for two reasons:

  1. It's simple on the user side: you specify a noise, and sampling and the (unnormalized) noise log-prob are automatically used together, so there is no risk of accidentally changing one and not the other.

  2. Having a single function feels a bit clunky: why do you need to compute the log_prob when you are sampling? Why do you need an rng to compute a log_prob? These are two different things in my opinion.

If there is another way to do these two things, happy to use it. The question of compatibility with distrax remains.

I think it's nice to add other noise distributions; a natural question is whether there should be a general distributions / random module within optax, rather than it living in the perturbations part.

q-berthet avatar Mar 12 '25 10:03 q-berthet

@q-berthet Thanks for your comment.

"there is no risk of accidentally changing one and not the other."

I think the same applies to the proposed noise_fn.

"why do you need to compute the log_prob when you are sampling?"

To quote Distrax:

Distrax distributions implement the method sample_and_log_prob, which provides samples and their log-probability in one line. For some distributions, this is more efficient than calling separately sample and log_prob.

and FlowJAX:

sample_and_log_prob: Sample the distribution and return the samples with their log probabilities. For transformed distributions (especially flows), this will generally be more efficient than calling the methods separately.
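To make the Distrax quote concrete, a minimal usage sketch of its public API:

```python
import distrax
import jax

dist = distrax.Normal(loc=0.0, scale=1.0)
# One call yields both the samples and their log-probabilities, which some
# distributions can compute jointly more cheaply than two separate calls.
samples, log_probs = dist.sample_and_log_prob(
    seed=jax.random.PRNGKey(0), sample_shape=(4,))
```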

The disadvantage of the object-based approach is that one needs to create an object class for every possible distribution, whereas functions are easier to make, compose, and reuse.

carlosgmartin avatar Mar 12 '25 18:03 carlosgmartin

@carlosgmartin sorry for the delay.

  • So I'd prefer to remove the classes and use functions.
  • Having one function that both samples and provides the unnormalized log-prob actually makes sense when trying to implement stochastic computational graphs (see e.g. https://arxiv.org/abs/1802.05098). In other words, having access to the unnormalized log-prob is a bit like having access to the jvp; see the sketch below.

Anyway, happy to move on with a change of API if it's OK with both of you, @q-berthet @carlosgmartin.
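To make the stochastic-graphs point concrete, here is a minimal score-function (REINFORCE-style) gradient estimator, where the sample and the score of its log-density naturally arise together. This is purely illustrative, not optax code:

```python
import jax
import jax.numpy as jnp

def grad_estimate(key, theta, f, num_samples=1000):
  eps = jax.random.normal(key, (num_samples,))
  z = theta + eps  # z ~ N(theta, 1)
  score = eps      # d/dtheta log N(z; theta, 1) = z - theta, i.e. eps here
  # Score-function identity: grad_theta E[f(z)] = E[f(z) * score]
  return jnp.mean(f(z) * score)
```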

vroulet avatar Apr 11 '25 16:04 vroulet

@vroulet Sounds good to me.

carlosgmartin avatar Apr 12 '25 02:04 carlosgmartin