optax
Clarify difference between kl_divergence and convex_kl_divergence
Description
This PR addresses issue #1513 by clarifying the documentation for kl_divergence and convex_kl_divergence.
The previous documentation did not explicitly state the mathematical difference between the two implementations, nor did it explain the specific use case for the convex variant.
Changes made:
- Updated `kl_divergence` docstring to identify it as the standard definition.
- Updated `convex_kl_divergence` docstring to:
  - Explicitly mention the added term: `sum(exp(log_predictions)) - sum(targets)`.
  - Clarify that this is the "Generalized KL Divergence" intended for unnormalized distributions.
- Improved phrasing for readability and elegance.
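
To illustrate why the added term matters, here is a minimal pure-Python sketch of the two formulas as described in this PR. The helper names `kl_sketch` and `convex_kl_sketch` are hypothetical and exist only for this example; they are not the optax implementations, and they skip batching, axes, and masking.

```python
import math


def kl_sketch(log_predictions, targets):
    # Standard definition: sum(t * (log(t) - log_p)), treating 0 * log(0) as 0.
    return sum(
        t * (math.log(t) - lp)
        for lp, t in zip(log_predictions, targets)
        if t > 0
    )


def convex_kl_sketch(log_predictions, targets):
    # Standard KL plus the added term: sum(exp(log_predictions)) - sum(targets).
    extra = sum(math.exp(lp) for lp in log_predictions) - sum(targets)
    return kl_sketch(log_predictions, targets) + extra


targets = [0.5, 0.5]

# Normalized predictions (sum to 1): the added term vanishes.
log_normalized = [math.log(0.25), math.log(0.75)]
normalized_gap = convex_kl_sketch(log_normalized, targets) - kl_sketch(
    log_normalized, targets
)

# Unnormalized predictions (sum to 2): the added term equals 2 - 1 = 1.
log_unnormalized = [math.log(0.5), math.log(1.5)]
standard_value = kl_sketch(log_unnormalized, targets)
generalized_value = convex_kl_sketch(log_unnormalized, targets)
```

When both inputs are normalized, the two values agree up to floating-point error. With unnormalized predictions, the standard formula can go negative, while the generalized form stays non-negative, which is why the convex variant suits unnormalized distributions.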
Fixes #1513
Type of change
- [x] Documentation update (no code logic change)