Can we support class weights in the CEWithChunkedOutputLoss class?
I am trying to add class weights to the loss function, and I think it would be nice to support them directly in this class. The original torch.nn.CrossEntropyLoss takes weight as an argument.
This seems like a good idea - thanks @ye-jin-shop!
Just so I fully understand what you're trying to do, can you provide a small code example of how you would be using this weight argument?
@joecummings Thank you for the response!
For example, suppose I have three classes [0, 1, 2] and want to put a lower weight on the first class. With CrossEntropyLoss from torch.nn, I can do:
import torch
import torch.nn as nn

weights = [0.5, 1.0, 1.0]
class_weights = torch.FloatTensor(weights).cuda()
criterion = nn.CrossEntropyLoss(weight=class_weights)
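To make the effect concrete, here is a self-contained CPU version of the same idea (the tensors are made up for illustration):

import torch
import torch.nn as nn

# Class 0 contributes half as much to the loss as classes 1 and 2
class_weights = torch.tensor([0.5, 1.0, 1.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(4, 3)          # (batch, num_classes)
labels = torch.tensor([0, 1, 2, 0])
loss = criterion(logits, labels)    # samples with label 0 contribute with weight 0.5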
That way, different classes get different weights when the loss is calculated. I wonder whether we could add a weight arg to this class CEWithChunkedOutputLoss, passing it into https://github.com/pytorch/torchtune/blob/17f0bffb3e544f033ae01c2f5a2a3a6e43f4bc90/torchtune/modules/loss/ce_chunked_output_loss.py#L41. It could be passed in from line 30, I think.
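Roughly, this is the plumbing I have in mind. The explicit weight parameter below just mirrors nn.CrossEntropyLoss and is only a sketch, not the actual torchtune code:

import torch
import torch.nn.functional as F
from typing import Optional

class CEWithChunkedOutputLoss(torch.nn.Module):
    # Sketch only: an explicit weight arg mirroring nn.CrossEntropyLoss
    def __init__(
        self,
        num_output_chunks: int = 8,
        ignore_index: int = -100,
        weight: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index
        self.weight = weight  # per-class weights, or None for the current behavior

    def compute_cross_entropy(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        # weight=None reproduces the existing unweighted loss
        return F.cross_entropy(
            logits.float(),
            labels,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction="sum",
        )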
This should be easy. We can add *args, **kwargs to the init, and something like the following should work, but it needs some testing to make sure there is no conflict with ignore_index and reduction:
import torch
import torch.nn.functional as F

class CEWithChunkedOutputLoss(torch.nn.Module):
    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100, *args, **kwargs):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index
        # Extra positional/keyword args (e.g. weight) are forwarded to F.cross_entropy
        self.args = args
        self.kwargs = kwargs

    def compute_cross_entropy(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Upcast logits to fp32 and compute cross entropy loss.
        """
        return F.cross_entropy(
            logits.float(),
            labels,
            *self.args,
            ignore_index=self.ignore_index,
            reduction="sum",
            **self.kwargs,
        )
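As a quick sanity check of the kwargs route (a sketch only; the class weights and shapes are made up):

import torch

# Hypothetical usage: class weights passed through the generic **kwargs
class_weights = torch.tensor([0.5, 1.0, 1.0])
loss_fn = CEWithChunkedOutputLoss(ignore_index=-100, weight=class_weights)

logits = torch.randn(4, 3)              # (batch, num_classes)
labels = torch.tensor([0, 1, 2, -100])  # -100 entries are ignored
loss = loss_fn.compute_cross_entropy(logits, labels)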
I can test something like this next week, but if you want to contribute directly, I would be glad to review your PR. Let me know and I can give some pointers. Otherwise, I will post here when I get back to this.
Hey @felipemello1. I have other priorities right now, so I will leave this to you (not urgent, but nice to have). Thank you for your support!