
Can we support class weight in the CEWithChunkedOutputLoss class

Open ye-jin-shop opened this issue 1 year ago • 4 comments

I am trying to add class weights to the loss function, and I think it would be nice to support this directly in the class. The original nn.CrossEntropyLoss takes weight as an argument.

ye-jin-shop avatar Oct 02 '24 21:10 ye-jin-shop

This seems like a good idea - thanks @ye-jin-shop!

Just so I fully understand what you're trying to do, can you provide a small code example of how you would be using this weight argument?

joecummings avatar Oct 02 '24 21:10 joecummings

@joecummings Thank you for the response!

For example, suppose I have three classes [0, 1, 2] and I want to put a lower weight on the first class. With CrossEntropyLoss from torch.nn I can write:

import torch
import torch.nn as nn

weights = [0.5, 1.0, 1.0]
class_weights = torch.FloatTensor(weights).cuda()
criterion = nn.CrossEntropyLoss(weight=class_weights)
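For context, a minimal usage sketch (the logits/targets values here are made up for illustration):

logits = torch.randn(4, 3).cuda()            # (batch, num_classes)
targets = torch.tensor([0, 1, 2, 0]).cuda()  # class indices
loss = criterion(logits, targets)            # samples of class 0 contribute with weight 0.5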

While calculating the loss, different classes would then contribute with different weights. I wonder whether, for CEWithChunkedOutputLoss, we could add a weight arg to https://github.com/pytorch/torchtune/blob/17f0bffb3e544f033ae01c2f5a2a3a6e43f4bc90/torchtune/modules/loss/ce_chunked_output_loss.py#L41. It could be passed in from line 30, I think.

ye-jin-shop avatar Oct 02 '24 21:10 ye-jin-shop

This should be easy. We can add *args, **kwargs to the init, and something like the snippet below should work, but it needs some testing to check whether there is any conflict with ignore_index and reduction:

import torch
import torch.nn.functional as F


class CEWithChunkedOutputLoss(torch.nn.Module):
    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100, *args, **kwargs):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index
        # Extra args/kwargs (e.g. weight=...) are stored and forwarded to F.cross_entropy.
        self.args = args
        self.kwargs = kwargs

    def compute_cross_entropy(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Upcast logits to fp32 and compute cross entropy loss.
        """
        return F.cross_entropy(
            logits.float(),
            labels,
            *self.args,
            ignore_index=self.ignore_index,
            reduction="sum",
            **self.kwargs,
        )
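With the kwargs forwarded like this, a class weight could be passed straight through the constructor. A minimal sketch, assuming the snippet above (nothing here exists in torchtune yet):

class_weights = torch.tensor([0.5, 1.0, 1.0])
loss_fn = CEWithChunkedOutputLoss(num_output_chunks=8, weight=class_weights)

One thing to verify is the normalization: reduction="sum" is kept here and, if I recall correctly, the loss is later divided by the number of non-ignored tokens, whereas nn.CrossEntropyLoss(weight=..., reduction="mean") divides by the sum of the per-token weights, so the two would not match exactly. This is part of the "conflict with reduction" to test.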

I can test something like this next week, but if you want to contribute directly, I would be glad to review your PR. Let me know and I can give some pointers. Otherwise, I will post here when I get back to this.

felipemello1 avatar Oct 03 '24 02:10 felipemello1

Hey @felipemello1, I have other priorities right now, so I will leave this to you (not urgent, but nice to have). Thank you for your support!

ye-jin-shop avatar Oct 03 '24 18:10 ye-jin-shop