
Clarify difference between `kl_divergence` and `convex_kl_divergence`

Open · carlosgmartin opened this issue 4 weeks ago · 1 comment

The difference between `kl_divergence` and `convex_kl_divergence` should be clarified.

The description for `kl_divergence` is

Computes the Kullback-Leibler divergence (relative entropy) loss.

Measures the information gain achieved if target probability distribution would be used instead of predicted probability distribution.

The description for `convex_kl_divergence` is

Computes a convex version of the Kullback-Leibler divergence loss.

Measures the information gain achieved if target probability distribution would be used instead of predicted probability distribution. This version is jointly convex in p (targets) and q (log_predictions).

But the KL divergence is already jointly convex in P and Q, so it is not clear what the difference is. Is the difference that `convex_kl_divergence` is convex in log Q (which the docstring confusingly calls q, even though that name usually refers to the actual probabilities)? If so, this should be stated more clearly.
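For concreteness, and under the (assumed) reading that P denotes the targets and Q the predicted probabilities, with the function receiving log_predictions = log Q, the quantity being computed is

$$D_{\mathrm{KL}}(P \,\|\, Q) = \sum_i P_i \left(\log P_i - \log Q_i\right),$$

so the convexity question above is about this expression viewed as a function of $(P, \log Q)$ rather than of $(P, Q)$.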

Implementation-wise, the only difference is that `convex_kl_divergence` adds the term $\sum_i Q_i - \sum_i P_i$.
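As a sanity check, here is a minimal sketch that verifies this relationship numerically. It assumes the `optax.losses` namespace and the `(log_predictions, targets)` argument order, and the example arrays are made up for illustration:

```python
import jax.numpy as jnp
import optax

p = jnp.array([[0.1, 0.6, 0.3]])   # targets P (a probability distribution)
q = jnp.array([[0.2, 0.5, 0.2]])   # predictions Q, deliberately unnormalized
log_q = jnp.log(q)

kl = optax.losses.kl_divergence(log_q, p)
ckl = optax.losses.convex_kl_divergence(log_q, p)

# If the implementations differ only by the extra term, these two values match:
print(ckl, kl + jnp.sum(q - p, axis=-1))
```

Note that when both P and Q are normalized, the extra term vanishes and the two losses coincide, so the difference only shows up for unnormalized inputs.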

carlosgmartin · Nov 28 '25 21:11

Hi! I would like to work on this issue. I can update the docstrings for `kl_divergence` and `convex_kl_divergence` to better reflect the differences in convexity and implementation mentioned above. Please assign this to me.

zer-art · Nov 29 '25 04:11