optax
Add `axis` and `where` arguments to loss functions
Feature request: Add the following arguments:
- `axis`: reduction axis (default `-1`)
- `where`: reduction mask (default `None`)
to the following loss functions:
- [ ] `convex_kl_divergence`
- [ ] `cosine_distance`
- [ ] `cosine_similarity`
- [ ] `kl_divergence`
- [ ] `sigmoid_focal_loss`
- [ ] `softmax_cross_entropy`
- [ ] `softmax_cross_entropy_with_integer_labels`
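A minimal sketch of how the two proposed arguments could behave, using two of the listed losses. These are hypothetical standalone functions written for illustration, not optax's current API or the eventual PR. The masking strategy is an assumption: masked entries are zeroed (or given `-inf` logits) with `jnp.where` before the reduction over `axis`.

```python
import jax
import jax.numpy as jnp


def softmax_cross_entropy(logits, labels, axis=-1, where=None):
    """Hypothetical sketch: cross entropy over `axis`, restricted to `where`.

    `where` is a boolean mask broadcastable to `logits`; False entries are
    treated as absent classes. Assumes every reduced slice keeps at least
    one unmasked entry.
    """
    if where is None:
        log_probs = jax.nn.log_softmax(logits, axis=axis)
        return -jnp.sum(labels * log_probs, axis=axis)
    # Masked logits become -inf, so they receive zero probability mass.
    safe_logits = jnp.where(where, logits, -jnp.inf)
    log_probs = jax.nn.log_softmax(safe_logits, axis=axis)
    # Replace the resulting -inf log-probs with 0 to avoid 0 * -inf = nan.
    log_probs = jnp.where(where, log_probs, 0.0)
    return -jnp.sum(labels * log_probs, axis=axis)


def cosine_similarity(predictions, targets, axis=-1, where=None, epsilon=0.0):
    """Hypothetical sketch: cosine similarity over `axis`, restricted to `where`."""
    if where is not None:
        # Masked entries contribute nothing to the dot product or the norms.
        predictions = jnp.where(where, predictions, 0.0)
        targets = jnp.where(where, targets, 0.0)
    dot = jnp.sum(predictions * targets, axis=axis)
    norm_p = jnp.linalg.norm(predictions, axis=axis)
    norm_t = jnp.linalg.norm(targets, axis=axis)
    # A positive epsilon guards against division by zero for all-zero slices.
    return dot / jnp.maximum(norm_p * norm_t, epsilon)
```

With a mask that drops the last class, the masked cross entropy matches the unmasked loss computed on the remaining classes only, and `axis=0` on an array gives the same result as `axis=-1` on its transpose. An alternative design would pass `where` straight to `jnp.sum`; zeroing first keeps the `-inf` log-probabilities out of the product entirely.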
I can submit a PR for this.
Hello @carlosgmartin,
Yes, that would be great. Thanks for catching this!
Great. Once #898 is merged I'll put together a PR.