keras-cv
Adds different class reduction scheme in `FocalLoss`
Short Description
Currently, `FocalLoss` reduces multi-label problems by taking a mean over classes (see below):
https://github.com/keras-team/keras-cv/blob/ff20079aebe8e80bc355c279c9c5541a7c823535/keras_cv/losses/focal.py#L90
However, a strictly correct formulation would be a sum. The original implementation (Facebook Research, PyTorch) lets the user choose between a sum and a mean. Other uses of focal loss also introduce a sum-over-non-zero-elements scheme (see below).
I think the API should either:
- provide no class reduction at all, and let users implement their own, or
- through an argument, provide either no class reduction or a `sum`, `mean`, or `sum_over_non_zero` class reduction.
Please note that this is fundamentally different from the `reduction` parameter of Keras losses, which applies primarily to batches. Here we are talking about how `FocalLoss`, a modified binary cross-entropy loss, should be extended from the single-class case to multi-class, multi-label cases.
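To make the proposal concrete, here is a minimal NumPy sketch of the class-reduction options, kept separate from any batch reduction. It is not the keras-cv implementation: the function names, the `class_reduction` argument, and the default `alpha` and `gamma` values are hypothetical, and `sum_over_non_zero` is read here as "sum over classes divided by the number of non-zero labels", which is only one possible interpretation.

```python
import numpy as np


def focal_loss_per_class(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    """Element-wise binary focal loss, shape (batch, num_classes).

    y_pred holds probabilities (after sigmoid), y_true holds 0/1 labels.
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # Probability assigned to the true outcome of each class.
    p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
    # Class-balancing weight: alpha for positives, 1 - alpha for negatives.
    alpha_t = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)


def reduce_classes(per_class, y_true, class_reduction="sum"):
    """Reduce the class axis only; the batch axis is left untouched.

    `class_reduction` is a hypothetical argument name for illustration.
    """
    if class_reduction in (None, "none"):
        return per_class
    if class_reduction == "sum":
        return per_class.sum(axis=-1)
    if class_reduction == "mean":
        return per_class.mean(axis=-1)
    if class_reduction == "sum_over_non_zero":
        # Assumed reading: normalize the sum by the number of positive labels,
        # with a floor of 1 to avoid dividing by zero.
        n_non_zero = np.maximum((y_true != 0).sum(axis=-1), 1)
        return per_class.sum(axis=-1) / n_non_zero
    raise ValueError(f"Unknown class_reduction: {class_reduction!r}")


# One sample, four labels, two of them positive.
y_true = np.array([[1.0, 0.0, 1.0, 0.0]])
y_pred = np.array([[0.9, 0.2, 0.6, 0.1]])

per_class = focal_loss_per_class(y_true, y_pred)
for mode in ("none", "sum", "mean", "sum_over_non_zero"):
    print(mode, reduce_classes(per_class, y_true, mode))
```

In this sketch `sum` and `mean` differ only by a factor equal to the number of classes, so the choice rescales the loss (and its gradients) relative to any other loss terms. The usual batch-level `reduction` would then be applied to the per-sample values returned here.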
What do you think @LukeWood @quantumalaviya ?
Papers
Existing Implementations
- PyTorch: https://pytorch.org/vision/stable/_modules/torchvision/ops/focal_loss.html
- Original Facebook research implementation: https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
- Implementation of FSAF: https://github.com/xuannianz/FSAF/blob/master/losses.py#L22-L60
Other Information
First discussed in #672
Gotcha, so do mean and sum in focal loss provide different results?
I think we should just let it accept `Reduction`, similar to other Keras losses? @atuleu
@atuleu, could you please make the required changes as per the comment in the linked PR? Thanks.