Support float targets for classification metrics
🚀 Feature
Support float targets (possibly softened one-hot encodings) for classification metrics.
Motivation
Techniques such as mixup or label smoothing start from one-hot encoded targets and convert them to soft labels. Ideally there would still be a way to use these with TM without getting an error. For example, you could threshold the target and then apply an argmax to get a reasonable result.
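For illustration, a minimal sketch of that conversion (assuming the class dimension is last):

```python
import torch

# soft targets produced by mixup / label smoothing, shape (batch, num_classes)
soft_target = torch.tensor([[0.7, 0.2, 0.1],
                            [0.1, 0.8, 0.1]])

# recover the hard class indices that classification metrics currently expect
hard_target = soft_target.argmax(dim=-1)  # tensor([0, 1])
```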
Pitch
This could be:
- done automatically based on some conditions
- enabled by passing a flag to the metric init
- enabled by passing a pre-metric transform with which the user converts their targets to the desired format (see the sketch after this list)
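As a rough sketch of the third option (`TransformedMetric` and `target_transform` are hypothetical names, not part of the TorchMetrics API):

```python
from torchmetrics import Accuracy

class TransformedMetric:
    """Hypothetical wrapper that applies a user-supplied target transform."""

    def __init__(self, metric, target_transform):
        self.metric = metric
        self.target_transform = target_transform

    def update(self, preds, target):
        self.metric.update(preds, self.target_transform(target))

    def compute(self):
        return self.metric.compute()

# e.g. reduce soft targets to hard labels before every update
acc = TransformedMetric(
    Accuracy(task="multiclass", num_classes=3),  # task=... is needed on newer TM versions
    target_transform=lambda t: t.argmax(dim=-1),
)
```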
Alternatives
Leave it as is and just require users to format preds and targets correctly.
Hi! Thanks for your contribution, great first issue!
@ethanwharris we already had this issue in https://github.com/PyTorchLightning/metrics/issues/74. How does this one differ, or which metric does not yet allow it?
@justusschock that issue is for model outputs / predictions, whereas this is for the targets. Currently, all classification metrics require targets that are int or long.
Whoops, sorry, I somehow missed that. I think the current design follows the design of the losses in PyTorch (e.g. CrossEntropyLoss), which also require integer targets (so that the same data format can be reused), but I think it makes sense to extend it.
Do you think simply relying on the dtype here would be safe? And the targets would then have to have a "channel" dimension (i.e. a dimension over which we can compute the argmax)?
Yeah, I think it could be safe to just do type inference. I would expect that targets in this case would be of the same dtype and shape as the preds. E.g. we are currently running something that uses mixup and BCEWithLogitsLoss, where the targets are soft and required to be the same shape as the preds. But I'm not sure if this holds in general or just overfits to that one use case haha...
> I would expect that targets in this case would be of the same dtype and shape as the preds.
I don't think we should make that assumption. It would then not allow integer predictions with float targets (and I'm pretty sure there is a use case for that somewhere).
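Concretely, the kind of dtype-based inference being discussed might look like this (a rough sketch only; `normalize_target` is a made-up helper, not TorchMetrics internals):

```python
import torch

def normalize_target(target: torch.Tensor, class_dim: int = 1) -> torch.Tensor:
    # treat floating-point targets as soft labels and reduce them to hard
    # labels over an assumed class dimension; integer targets pass through
    if target.is_floating_point():
        return target.argmax(dim=class_dim)
    return target
```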
Hey, I am also pushing this issue for usability reasons. When using, for instance, binary cross entropy loss (BCELoss), the targets are required to be float, yet when using the same targets with torchmetrics everything needs to be converted to int, which is quite an overhead, I guess?
@benjs We discussed this as part of #1001 and likely won't enable this soon (if at all). The reason for that is twofold:
1. You may be able to train with soft labels, but for validation, classification metrics ultimately have to make a decision for each class.
2. Because of 1., if we enabled this, all we would do is cast to int internally. We discussed this and chose not to support it, as it would make our API more complex (it is already quite complicated to see which metric supports which kinds of targets). So, to be more explicit, we opted for a unified API with fixed supported types and have the user do the conversion explicitly. Again, this is exactly what we would do internally, so there is no additional overhead (it can even be less overhead when you convert once and pass the result to multiple metrics)!
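For illustration, a minimal sketch of the convert-once pattern (the `task` argument is required on newer TorchMetrics versions):

```python
import torch
from torchmetrics import Accuracy, Precision

preds = torch.rand(8)                        # predicted probabilities
target = torch.randint(0, 2, (8,)).float()   # float 0/1 targets, as BCELoss expects

target_int = target.long()                   # convert once...

accuracy = Accuracy(task="binary")
precision = Precision(task="binary")
accuracy.update(preds, target_int)           # ...then reuse for several metrics
precision.update(preds, target_int)
```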
If working on a larger project, there is also the option to subclass metrics like so:
```python
from torchmetrics import Accuracy

class FloatSupportAccuracy(Accuracy):
    def update(self, preds, target):
        super().update(preds, target.long())  # cast float targets to int
```
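Note that recent TorchMetrics versions dispatch `Accuracy(task=...)` to a task-specific class, so an override on `Accuracy` itself may never run there; the same idea applied to the concrete class (a sketch, assuming a multiclass setup):

```python
from torchmetrics.classification import MulticlassAccuracy

class FloatSupportMulticlassAccuracy(MulticlassAccuracy):
    def update(self, preds, target):
        # same idea: cast float targets to int before the base update
        super().update(preds, target.long())
```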