Possible improvements for Accuracy
The feature request is described in full detail here; below is a quick recap.
There are two inconveniences I experience with the current interface of Accuracy.
1. Inconsistent input format for binary classification and multiclass problems
In the first case, Accuracy expects labels as input, whilst in the second case it expects probabilities/logits. I am a somewhat experienced Ignite user and I still get confused by this behavior.
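For illustration, a minimal sketch of the current behaviour (the tensors are just example values):
import torch
from ignite.metrics import Accuracy

# binary case: y_pred must already be 0/1 labels of shape (N, ...)
acc = Accuracy()
acc.update((torch.tensor([1, 0, 1, 1]), torch.tensor([1, 0, 0, 1])))

# multiclass case: the same metric expects y_pred as raw logits/probabilities of shape (N, C)
acc = Accuracy()
acc.update((torch.randn(4, 3), torch.tensor([0, 2, 1, 0])))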
2. No shortcuts for saying "I want to pass logits/probabilities as input"
In practice, I have never used Accuracy in the following manner for binary classification:
accuracy = Accuracy()
Instead, I always do one of the following:
accuracy = Accuracy(output_transform=lambda out: (torch.round(torch.sigmoid(out[0])), out[1]))
# or
accuracy = Accuracy(output_transform=lambda out: (torch.round(out[0]), out[1]))
Suggested solution for both problems: let the user explicitly say in which form input will be passed:
import enum

class Accuracy(...):
    class Mode(enum.Enum):
        LABELS = enum.auto()
        PROBABILITIES = enum.auto()
        LOGITS = enum.auto()

    def __init__(self, mode=Mode.LABELS, ...):
        ...
The suggested interface can also be extended to support custom thresholds by adding a __call__ method to the Mode class.
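A rough illustration of what that could look like (hypothetical code, not the existing Ignite API; the sigmoid step and the 0.5 default are my own assumptions):
import enum
import torch

class Mode(enum.Enum):
    LABELS = enum.auto()
    PROBABILITIES = enum.auto()
    LOGITS = enum.auto()

    def __call__(self, y_pred, threshold=0.5):
        # turn the metric input into 0/1 labels according to the declared mode
        if self is Mode.LOGITS:
            y_pred = torch.sigmoid(y_pred)
        if self is not Mode.LABELS:
            y_pred = (y_pred >= threshold).long()
        return y_pred

# e.g. Mode.LOGITS(torch.tensor([-2.0, 0.3, 1.5])) gives tensor([0, 1, 1])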
@WeirdKeksButtonK I really appreciate this API! Thank you very much 👍🏻
Hello, I believe I could be assigned to this issue, since I have a PR for it.
@vcarpani sure! On GitHub we cannot assign arbitrary users to an issue, only project members or people who have participated in the conversation here.
Hi everyone, I would like to try improving this issue.
Sure @sallycaoyu , please check also all related PRs and mentions.
For now, I am trying to finish implementing a binary_mode for binary and multilabel types to transform probabilities and logits into 0s and 1s, as this PR has done. If that works well, I can then consider how to add more flexibility to multiclass, as issue #822 suggests.
Does that sound like a good plan? Or would you like Ignite to have a mode similar to what this issue suggests, i.e., mode being one of [binary, multiclass, multilabel] instead of one of [unchanged, probabilities, logits]? The former would require more modifications to what we have right now, such as removing is_multilabel and replacing it with mode for Accuracy, Precision, Recall, and ClassificationReport, because multilabel would then be one option of mode.
@sallycaoyu thanks for the update! I think we can continue with mode as [unchanged, probabilities, logits, labels?].
Can you please sketch the new API usage with code snippets, emphasizing "before" and "after"? For example:
### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits
### after
acc = Accuracy(mode=Accuracy.LOGITS)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C), (N, )
acc = Accuracy(mode=Accuracy.LABELS)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass labels : (N, ), (N, )
etc
Sure! Suppose we have:
from typing import Callable, Union

import torch

class Accuracy:
    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        is_multilabel: bool = False,
        device: Union[str, torch.device] = torch.device("cpu"),
        mode: str = 'unchanged',
        threshold: Union[float, int] = 0.5,
    ):
        ...
Then, for binary and multilabel data:
### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (0s and 1s) : (N, ...), (N, ...), or (N, 1, ...), (N, ...)
acc = Accuracy(is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (0s and 1s) : (N, C, ...), (N, C, ...)
### after
# LOGITS MODE
acc = Accuracy(mode='logits', threshold = 3.25)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary logits (float in [-inf, inf]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)
acc = Accuracy(mode='logits', threshold = 3.25, is_multilabel = True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel logits (float in [-inf, inf]): (N, C, ...), (N, C, ...)
# in this case, Accuracy will transform any value < 3.25 to be 0, value >= 3.25 to be 1
# if not passing a threshold, Accuracy will sigmoid the logits, and then transform any value < 0.5 to be 0, >= 0.5 to be 1
# PROBABILITIES MODE
acc = Accuracy(mode='probabilities', threshold = 0.6)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary probabilities (float in [0, 1]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)
acc = Accuracy(mode='probabilities', threshold = 0.6, is_multilabel = True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel probabilities (float in [0, 1]): (N, C, ...), (N, C, ...)
# in this case, Accuracy will transform any value < 0.6 to be 0, value >= 0.6 to be 1
# if not passing a threshold, Accuracy will transform any value < 0.5 to be 0, >= 0.5 to be 1
# LABELS MODE
acc = Accuracy(mode='labels', threshold = 5)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (int in [0, inf]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)
acc = Accuracy(mode='labels', threshold = 5, is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (int in [0, inf]): (N, C, ...), (N, C, ...)
# in this case, Accuracy will transform any value < 5 to be 0, >= 5 to be 1
# must specify a threshold for labels mode
# UNCHANGED MODE
acc = Accuracy(mode='unchanged')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (0s and 1s): (N, ...), (N, ...), or (N, 1, ...), (N, ...)
acc = Accuracy(mode='unchanged', is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (0s and 1s): (N, C, ...), (N, C, ...)
# will work like before: raise an error when any value is not 0 or 1
# should not specify a threshold for unchanged mode
For multiclass data:
### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)
### after: should not apply threshold to multiclass data
# LABELS MODE
acc = Accuracy(mode='labels')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass labels : (N, ...), (N, ...)
# conflict with _check_type(), since we use y.ndimension() + 1 == y_pred.ndimension() to check for multiclass data
# For now, the following multiclass modes will work like before (argmax):
# PROBABILITIES MODE
acc = Accuracy(mode='probabilities')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass probabilities : (N, C, ...), (N, ...)
# LOGITS MODE
acc = Accuracy(mode='logits')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)
# UNCHANGED MODE
acc = Accuracy(mode='unchanged')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)
Thanks a lot for the snippet @sallycaoyu!
I have a few thoughts about that:
- Would it make sense to introduce mode=multilabels or some other appropriate name to hint at multiclass data?
- I'm not sure about the usefulness of the threshold arg. If we want to threshold logits/probas to labels, we can use output_transform with any rule/threshold we want:
# binary data:
acc = Accuracy(mode='logits', output_transform=lambda out: ((out[0] > 0).to(dtype=torch.long), out[1]))
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary logits (float in [-inf, inf]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)
- Maybe we can also drop the unchanged mode?
What do you think?
@vfdev-5 Thank you very much for the comments!
I agree that we can drop the unchanged mode. And I also agree that output_transform can give users more flexibility than threshold, so threshold is not really necessary. Then by default (a rough sketch follows the list below), for:
- probabilities mode:
  - for binary and multilabel, we can round data to 0 or 1 and compare with y_true
  - for multiclass, we can take argmax for now
- logits mode:
  - for binary and multilabel, we can apply sigmoid, then round to 0 or 1 and compare with y_true
  - for multiclass, we can take argmax for now
- labels mode: I am actually not so sure about how to handle this situation.
  - for multilabel, we can one-hot y_pred and y_true, then compare them
  - for multiclass, we can directly compare y_pred and y_true if they are of the same shape (N, ...), (N, ...) or (N, 1, ...), (N, ...)
  - for binary, how do we map class 0 - class N to 0 or 1?
- Maybe we should not enable this mode for binary data, but I don't think calling this mode multilabels is a good name for multiclass data, because it should be differentiated from the results of multilabel classification. Would nonbinary_labels be a better name for this mode?
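As a rough sketch of these defaults (my own illustration in plain PyTorch, not Ignite code; the helper names are made up):
import torch

def binary_or_multilabel_to_labels(y_pred, mode):
    # proposed default: 'logits' -> sigmoid -> round, 'probabilities' -> round
    if mode == 'logits':
        y_pred = torch.sigmoid(y_pred)
    return torch.round(y_pred)

def multiclass_to_labels(y_pred):
    # proposed default for both 'logits' and 'probabilities': argmax over the class dimension
    return torch.argmax(y_pred, dim=1)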