pytorch-multi-class-focal-loss
Multi-class Focal Loss
An (unofficial) implementation of Focal Loss, as described in the RetinaNet paper, https://arxiv.org/abs/1708.02002, generalized to the multi-class case.
It is essentially an extension of cross-entropy loss and is useful for classification tasks with a large class imbalance. It down-weights the loss on easy (already well-classified) examples, so that training concentrates on the hard, misclassified ones.
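For reference, the RetinaNet paper defines the loss as FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the predicted probability of the true class. With γ = 0 it reduces to (α-weighted) cross-entropy. The plain-Python sketch below computes it for a single example so you can see how the (1 - p_t)^γ factor shrinks the loss on easy examples. The function name focal_loss_single is hypothetical and not part of this repo. α is treated as a per-class weight, as the alpha tensors in the examples further down suggest.

```python
import math
from typing import Optional, Sequence


def focal_loss_single(
    logits: Sequence[float],
    target: int,
    gamma: float = 2.0,
    alpha: Optional[Sequence[float]] = None,
) -> float:
    """Hypothetical helper: focal loss for one example, -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    # Numerically stable log-softmax: subtract the max logit before exponentiating.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    log_pt = logits[target] - log_sum_exp  # log-probability of the true class
    pt = math.exp(log_pt)
    alpha_t = 1.0 if alpha is None else alpha[target]  # assumed per-class weight
    return -alpha_t * (1.0 - pt) ** gamma * log_pt


# An "easy" example (true class already highly probable) vs. a "hard" one.
easy = focal_loss_single([4.0, 0.0], target=0, gamma=2.0)
hard = focal_loss_single([0.0, 4.0], target=0, gamma=2.0)
ce_easy = focal_loss_single([4.0, 0.0], target=0, gamma=0.0)  # gamma=0 -> plain cross-entropy
print(f"easy example: cross-entropy={ce_easy:.6f}, focal={easy:.8f}")
print(f"hard example: focal={hard:.4f}")
```

With γ = 2 the easy example's loss falls to a tiny fraction of its cross-entropy value, while the hard example keeps almost all of its loss.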
Usage
- FocalLoss is an nn.Module and behaves very much like nn.CrossEntropyLoss(), i.e. it
  - supports the reduction and ignore_index params, and
  - is able to work with 2D inputs of shape (N, C) as well as K-dimensional inputs of shape (N, C, d1, d2, ..., dK).
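The behavior listed above can be made concrete with a minimal sketch of how such a module might be written. This is an illustration under stated assumptions, not the code shipped in this repo. The class name FocalLossSketch is hypothetical, alpha is assumed to be a per-class weight tensor, and the 'mean' reduction is assumed to be a plain mean over the non-ignored elements.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossSketch(nn.Module):
    """Hypothetical minimal multi-class focal loss; NOT the repo's FocalLoss."""

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,  # assumed: one weight per class
        gamma: float = 0.0,
        reduction: str = "mean",
        ignore_index: int = -100,
    ):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if x.ndim > 2:
            # (N, C, d1, ..., dK) -> (N * d1 * ... * dK, C); flatten targets to match.
            c = x.shape[1]
            x = x.movedim(1, -1).reshape(-1, c)
            y = y.reshape(-1)

        # Drop positions whose target equals ignore_index.
        keep = y != self.ignore_index
        x, y = x[keep], y[keep]
        if y.numel() == 0:
            return x.new_zeros(())

        log_p = F.log_softmax(x, dim=-1)
        log_pt = log_p.gather(1, y.unsqueeze(1)).squeeze(1)  # log-prob of the true class
        pt = log_pt.exp()

        # -(1 - p_t)^gamma * log(p_t), optionally scaled by alpha_t.
        loss = -((1 - pt) ** self.gamma) * log_pt
        if self.alpha is not None:
            loss = self.alpha.to(device=x.device, dtype=x.dtype)[y] * loss

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
```

With gamma=0 and no alpha, this sketch should match F.cross_entropy, which makes a quick sanity check. It makes two deliberate simplifications. First, with reduction='none' ignored positions are dropped rather than returned as zeros. Second, when alpha is set, 'mean' is an unweighted average, whereas nn.CrossEntropyLoss with a weight argument divides by the sum of the target weights.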
Example usage
focal_loss = FocalLoss(alpha, gamma)
...
inp, targets = batch
out = model(inp)
loss = focal_loss(out, targets)
Loading through torch.hub
This repo supports importing modules through torch.hub. FocalLoss can be easily imported into your code via, for example:
import torch

focal_loss = torch.hub.load(
	'adeelh/pytorch-multi-class-focal-loss',
	model='FocalLoss',
	alpha=torch.tensor([.75, .25]),
	gamma=2,
	reduction='mean',
	force_reload=False
)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)
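Or, using the focal_loss entry point, which takes alpha as a plain list together with device and dtype arguments: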
import torch

focal_loss = torch.hub.load(
	'adeelh/pytorch-multi-class-focal-loss',
	model='focal_loss',
	alpha=[.75, .25],
	gamma=2,
	reduction='mean',
	device='cpu',
	dtype=torch.float32,
	force_reload=False
)
x, y = torch.randn(10, 2), (torch.rand(10) > .5).long()
loss = focal_loss(x, y)