
Clarify difference between kl_divergence and convex_kl_divergence

Open zer-art opened this issue 3 weeks ago • 0 comments

Description

This PR addresses issue #1513 by clarifying the documentation for kl_divergence and convex_kl_divergence.

The previous documentation did not explicitly state the mathematical difference between the two implementations, nor did it explain the specific use case for the convex variant.

Changes made:

  • Updated kl_divergence docstring to identify it as the standard definition.
  • Updated convex_kl_divergence docstring to:
    • Explicitly mention the added term: sum(exp(log_predictions)) - sum(targets).
    • Clarify that this is the "Generalized KL Divergence" intended for unnormalized distributions.
    • Improve phrasing for readability.

Fixes #1513
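
To illustrate the distinction the updated docstrings describe, here is a minimal NumPy sketch of the two formulas. This is not the optax implementation, just toy re-implementations following the math stated above (standard KL versus generalized KL with the extra `sum(exp(log_predictions)) - sum(targets)` term); the function names are borrowed from optax purely for illustration:

```python
import numpy as np

def kl_divergence(log_predictions, targets):
    # Standard KL divergence: sum(q * (log q - log p)).
    # Only meaningful when both inputs are normalized distributions.
    return np.sum(targets * (np.log(targets) - log_predictions))

def convex_kl_divergence(log_predictions, targets):
    # Generalized KL divergence: adds sum(p) - sum(q). The extra term
    # vanishes for normalized inputs but keeps the loss non-negative
    # (and convex) even for unnormalized ones.
    return (kl_divergence(log_predictions, targets)
            + np.sum(np.exp(log_predictions)) - np.sum(targets))

q = np.array([0.2, 0.3, 0.5])   # normalized targets
p = np.array([0.25, 0.25, 0.5]) # normalized predictions

# For normalized distributions the two definitions coincide,
# since sum(exp(log p)) - sum(q) = 1 - 1 = 0.
print(np.isclose(kl_divergence(np.log(p), q),
                 convex_kl_divergence(np.log(p), q)))  # prints True
```

With unnormalized predictions (e.g. rows that do not sum to 1), the standard formula can go negative while the generalized form stays non-negative, which is why the convex variant is the right choice for unnormalized distributions.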

Type of change

  • [x] Documentation update (no code logic change)

zer-art Nov 29 '25 05:11