optax
Add `axis` and `where` arguments to loss functions
Feature request: Add the following arguments:
- `axis`: reduction axis (default `-1`)
- `where`: reduction mask (default `None`)

to the following loss functions:

- [ ] `convex_kl_divergence`
- [ ] `cosine_distance`
- [ ] `cosine_similarity`
- [ ] `kl_divergence`
- [ ] `sigmoid_focal_loss`
- [ ] `softmax_cross_entropy`
- [ ] `softmax_cross_entropy_with_integer_labels`
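Here is a minimal sketch of how `axis` and `where` could behave for `softmax_cross_entropy`. It is written in plain NumPy rather than JAX so that it has no dependencies. This is not optax's implementation, and the function below is hypothetical. It assumes one plausible reading of `where`: masked-out entries are dropped from both the softmax normalization and the final sum over `axis`.

```python
import numpy as np


def softmax_cross_entropy(logits, labels, axis=-1, where=None):
    """Hypothetical sketch of softmax cross entropy with `axis` and `where`.

    `axis` is the class axis that is normalized and reduced over.
    `where` is a boolean mask (broadcastable to `logits`) marking the
    entries that take part in the reduction. Masked-out entries are
    excluded from both the softmax normalization and the final sum.
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    if where is None:
        where = np.ones(logits.shape, dtype=bool)
    else:
        where = np.broadcast_to(np.asarray(where, dtype=bool), logits.shape)

    # Shift by the max over unmasked entries for numerical stability.
    masked_logits = np.where(where, logits, -np.inf)
    shift = np.max(masked_logits, axis=axis, keepdims=True)
    shift = np.where(np.isfinite(shift), shift, 0.0)  # fully masked slices
    shifted = logits - shift

    # Log-normalizer computed over unmasked entries only.
    exp = np.exp(np.where(where, shifted, -np.inf))
    total = np.sum(exp, axis=axis, keepdims=True)
    log_norm = np.log(np.where(total > 0, total, 1.0))
    log_probs = shifted - log_norm

    # Reduce over `axis`, skipping masked-out entries.
    return -np.sum(labels * log_probs, axis=axis, where=where)
```

For example, `softmax_cross_entropy(logits, labels, where=np.array([False, True, True]))` ignores the first class entirely. A slice that is fully masked yields a loss of 0. Because `jax.numpy` provides the same calls (including `where=` on `sum`), the structure should port to JAX mostly by swapping the import. That port is an assumption, not something the issue states.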
I can submit a PR for this.
Hello @carlosgmartin,
Yes, that would be great. Thanks for catching this!
Great. Once #898 is merged, I'll put together a PR.