
Possible improvements for Accuracy

Open Yura52 opened this issue 4 years ago • 10 comments

The feature request is described in full detail here; below is a quick recap.

There are two inconveniences I experience with the current interface of Accuracy.

1. Inconsistent input format for binary classification and multiclass problems

In the first case, Accuracy expects labels as input, whilst in the second case it expects probabilities/logits. I am a somewhat experienced Ignite user and I still get confused by this behavior.
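
For concreteness, here is a minimal sketch of the two current input conventions (using the existing ignite.metrics.Accuracy API; shapes and values are illustrative):

import torch
from ignite.metrics import Accuracy

# Binary case: y_pred must already contain 0/1 labels.
acc_binary = Accuracy()
acc_binary.update((torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0])))

# Multiclass case: y_pred must be per-class scores of shape (N, C);
# the metric takes the argmax over classes internally.
acc_multiclass = Accuracy()
acc_multiclass.update((torch.randn(4, 3), torch.tensor([0, 2, 1, 2])))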

2. No shortcuts for saying "I want to pass logits/probabilities as input"

In practice, I have never used Accuracy in the following manner for binary classification:

accuracy = Accuracy()

Instead, I always do one of the following:

accuracy = Accuracy(output_transform=lambda out: (torch.round(torch.sigmoid(out[0])), out[1]))
# or
accuracy = Accuracy(output_transform=lambda out: (torch.round(out[0]), out[1]))

Suggested solution for both problems: let the user explicitly say in which form the input will be passed:

import enum
class Accuracy(...):
    class Mode(enum.Enum):
        LABELS = enum.auto()
        PROBABILITIES = enum.auto()
        LOGITS = enum.auto()

    def __init__(self, mode=Mode.LABELS, ...):
        ...

The suggested interface can also be extended to support custom thresholds by adding a __call__ method to the Mode class.
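
As a hypothetical illustration of that extension (the bundling of a threshold into the mode is an assumption, not an existing API), calling an enum member could attach a custom threshold:

import enum

class Mode(enum.Enum):
    LABELS = enum.auto()
    PROBABILITIES = enum.auto()
    LOGITS = enum.auto()

    def __call__(self, threshold=0.5):
        # e.g. Mode.LOGITS(0.0) bundles a custom decision threshold
        # together with the declared input format
        return (self, threshold)

# hypothetical usage: accuracy = Accuracy(mode=Mode.LOGITS(0.0))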

Yura52 avatar May 31 '20 14:05 Yura52

@WeirdKeksButtonK I really appreciate this API! Thank you very much 👍🏻

sdesrozis avatar Jun 04 '20 20:06 sdesrozis

Hello, I believe I could be assigned to this issue, since I have a PR for it

vcarpani avatar Oct 12 '20 14:10 vcarpani

@vcarpani sure! On GitHub we cannot assign arbitrary users to an issue, only project members or people who have participated in the conversation here.

vfdev-5 avatar Oct 12 '20 14:10 vfdev-5

Hi everyone, I would like to try working on this issue.

sallycaoyu avatar Feb 25 '23 01:02 sallycaoyu

Sure @sallycaoyu, please also check all related PRs and mentions.

vfdev-5 avatar Feb 25 '23 08:02 vfdev-5

For now, I am trying to finish implementing a binary_mode for binary and multilabel types that transforms probabilities and logits into 0s and 1s, as this PR has done. If that works well, I can then consider how to add more flexibility for multiclass, as issue #822 suggests.

Does that sound like a good plan? Or would you like Ignite to have a mode similar to what this issue suggests, i.e., mode as one of [binary, multiclass, multilabel] instead of one of [unchanged, probabilities, logits]? The former would require more modifications to what we have right now, such as removing is_multilabel and replacing it with mode in Accuracy, Precision, Recall, and ClassificationReport, because multilabel would then be one option of mode.

sallycaoyu avatar Feb 27 '23 22:02 sallycaoyu

@sallycaoyu thanks for the update! I think we can continue with mode as [unchanged, probabilities, logits, labels?]. Can you please sketch the new API usage with code snippets, emphasizing the "before" and "after"? For example:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits 

### after
acc = Accuracy(mode=Accuracy.LOGITS)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C), (N, )

acc = Accuracy(mode=Accuracy.LABELS)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass labels  : (N, ), (N, )

etc

vfdev-5 avatar Feb 27 '23 23:02 vfdev-5

Sure! Suppose we have:

class Accuracy:
    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        is_multilabel: bool = False,
        device: Union[str, torch.device] = torch.device("cpu"),
        mode: str = 'unchanged',
        threshold: Union[float, int] = 0.5,
    ):
        ...

Then, for binary and multilabel data:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (0s and 1s) : (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (0s and 1s) : (N, C, ...), (N, C, ...)



### after
# LOGITS MODE
acc = Accuracy(mode='logits', threshold=3.25)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary logits (float in (-inf, inf)): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='logits', threshold=3.25, is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel logits (float in (-inf, inf)): (N, C, ...), (N, C, ...)

# in this case, Accuracy will transform any value < 3.25 to 0 and any value >= 3.25 to 1
# if no threshold is passed, Accuracy will apply a sigmoid to the logits, then transform any value < 0.5 to 0 and >= 0.5 to 1



# PROBABILITIES MODE
acc = Accuracy(mode='probabilities', threshold=0.6)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary probabilities (float in [0, 1]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='probabilities', threshold=0.6, is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel probabilities (float in [0, 1]): (N, C, ...), (N, C, ...)

# in this case, Accuracy will transform any value < 0.6 to 0 and any value >= 0.6 to 1
# if no threshold is passed, Accuracy will transform any value < 0.5 to 0 and any value >= 0.5 to 1



# LABELS MODE
acc = Accuracy(mode='labels', threshold=5)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (non-negative int): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='labels', threshold=5, is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (non-negative int): (N, C, ...), (N, C, ...)

# in this case, Accuracy will transform any value < 5 to 0 and any value >= 5 to 1
# a threshold must be specified for labels mode




# UNCHANGED MODE
acc = Accuracy(mode='unchanged')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (0s and 1s): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='unchanged', is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (0s and 1s): (N, C, ...), (N, C, ...)

# will work as before: raise an error when any value is not 0 or 1
# a threshold should not be specified for unchanged mode
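
A minimal sketch of the binary/multilabel decision rule described above (a hypothetical helper; the names and defaults are assumptions based on this proposal):

import torch

def _to_binary_labels(y_pred, mode='unchanged', threshold=None):
    # hypothetical helper mirroring the proposed per-mode behavior
    # for binary and multilabel y_pred
    if mode == 'logits':
        if threshold is None:
            return (torch.sigmoid(y_pred) >= 0.5).long()
        return (y_pred >= threshold).long()
    if mode == 'probabilities':
        return (y_pred >= (0.5 if threshold is None else threshold)).long()
    if mode == 'labels':
        if threshold is None:
            raise ValueError("labels mode requires an explicit threshold")
        return (y_pred >= threshold).long()
    return y_pred  # 'unchanged': values must already be 0s and 1s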

For multiclass data:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)



### after: should not apply threshold to multiclass data
# LABELS MODE
acc = Accuracy(mode='labels')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass labels : (N, ...), (N, ...)
# conflict with _check_type(), since we use y.ndimension() + 1 == y_pred.ndimension() to check for multiclass data


# For now, the following multiclass modes will work like before (argmax):
# PROBABILITIES MODE
acc = Accuracy(mode='probabilities')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass probabilities : (N, C, ...), (N, ...)


# LOGITS MODE
acc = Accuracy(mode='logits')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...) 


# UNCHANGED MODE
acc = Accuracy(mode='unchanged')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)
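
For the multiclass cases above, probabilities and logits modes both reduce to the current behavior, i.e. an argmax over the class dimension:

import torch

# multiclass y_pred of shape (N, C, ...) becomes labels of shape (N, ...)
y_pred_labels = torch.randn(4, 3).argmax(dim=1)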

sallycaoyu avatar Mar 01 '23 00:03 sallycaoyu

Thanks a lot for the snippet @sallycaoyu !

I have a few thoughts about that:

  • would it make sense to introduce mode=multilabels or some other appropriate name to hint at multiclass data?

  • I'm not sure about the usefulness of the threshold arg. If we want to threshold logits/probas to labels, we can use output_transform with any rule/threshold we want (a probabilities variant is sketched after this list):

# binary data:
acc = Accuracy(mode='logits', output_transform=lambda out: ((out[0] > 0).to(dtype=torch.long), out[1]))
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary logits (float in (-inf, inf)): (N, ...), (N, ...), or (N, 1, ...), (N, ...)
  • Maybe we can also drop the unchanged mode?
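
For instance (hypothetical usage of the proposed API), a custom probability cut-off also needs no dedicated argument:

acc = Accuracy(mode='probabilities', output_transform=lambda out: ((out[0] >= 0.6).to(dtype=torch.long), out[1]))
acc.attach(evaluator, "acc")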

What do you think?

vfdev-5 avatar Mar 01 '23 00:03 vfdev-5

@vfdev-5 Thank you very much for the comments!

I agree that we can drop the unchanged mode. I also agree that output_transform gives users more flexibility than threshold, so threshold is not really necessary. Then, by default, for:

  • probabilities mode:
    • for binary and multilabel, we can round data to 0 or 1 and compare with y_true
    • for multiclass, we can take argmax for now
  • logits mode:
    • for binary and multilabel, we can do sigmoid then round to 0 or 1 and compare with y_true
    • for multiclass, we can take argmax for now
  • labels mode: I am actually not so sure about how to handle this situation.
    • for multilabel, we can one-hot y_pred and y_true, then compare them
    • for multiclass, we can directly compare y_pred and y_true if they are of the same shape (N, ...), (N, ...) or (N, 1, ...), (N, ...)
    • for binary, how do we map class 0 - class N to 0 or 1?
    • Maybe we should not enable this mode for binary data. Also, I don't think multilabels is a good name for this mode when applied to multiclass data, because it should be differentiated from the results of multilabel classification. Would nonbinary_labels be a better name for this mode?
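
A compact sketch of these proposed defaults (a hypothetical helper; the multilabel labels branch assumes y_pred holds class indices to be one-hot encoded):

import torch
import torch.nn.functional as F

def _default_decision(y_pred, mode, data_type, num_classes=None):
    # hypothetical sketch of the defaults proposed above
    if mode == 'labels':
        if data_type == 'multilabel':
            return F.one_hot(y_pred, num_classes)  # compare one-hot encodings
        return y_pred                              # multiclass: compare directly
    if data_type == 'multiclass':
        return y_pred.argmax(dim=1)                # probabilities/logits: argmax
    if mode == 'logits':
        y_pred = torch.sigmoid(y_pred)             # binary/multilabel logits
    return torch.round(y_pred).long()              # round probabilities to 0/1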

sallycaoyu avatar Mar 01 '23 02:03 sallycaoyu