CTC-OptimizedLoss

How to use the o1 loss?

Open · teinhonglo opened this issue 1 year ago · 3 comments

Thanks for sharing the code. Could you provide an example of how to use the o1 loss? I've combined it with the CTC loss as shown in the code below, but performance has not improved.

# (T, B, V) log-probabilities expected by nn.functional.ctc_loss
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

with torch.backends.cudnn.flags(enabled=False):
    loss = nn.functional.ctc_loss(
        log_probs,
        flattened_targets,
        input_lengths,
        target_lengths,
        blank=self.config.pad_token_id,
        reduction=self.config.ctc_loss_reduction,    # default: "sum"; use_focal_loss=None
        zero_infinity=self.config.ctc_zero_infinity, # default: False
    )

# o1_loss takes batch-first (B, T, V) input, so transpose back
o1_loss = self.o1_loss(
    log_probs.transpose(0, 1),
    input_lengths,
    labels,
    target_lengths,
)

if self.use_o1_loss:
    o1_loss /= batch_size
    loss = 0.01 * loss + 1. * o1_loss
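
(Side note: one thing worth checking in the combination above is scale. With reduction="sum" the CTC loss grows with batch size and sequence length, while o1_loss is divided by batch_size, so a fixed 0.01 weight mixes terms of very different magnitudes. A minimal sketch of putting both on a per-sample scale before weighting; the weight values are illustrative, not tuned:)

# Sketch (assumption, not the repository's API): normalize both losses
# to a per-sample scale so the mixing weights are comparable.
# Assumes the CTC loss above used reduction="sum".
ctc_weight, o1_weight = 0.01, 1.0   # illustrative values
loss = ctc_weight * (loss / batch_size) + o1_weight * (o1_loss / batch_size)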

teinhonglo · Feb 03 '24 14:02

Beam search plays a far larger role in RNN-T than in CTC decoding, and I agree with your conclusion on this. At the moment, this loss function is only an experimental exercise.

TeaPoly · Feb 04 '24 02:02

Thanks for your response.

I have another question regarding CTC optimization. In your experience, what modification in this repository has been most beneficial for reducing CTC loss?

teinhonglo · Feb 14 '24 09:02

Inter-CTC is very useful for deep NN models, and CTC-CRF is useful for small datasets: https://github.com/thu-spmi/CAT
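
For context, intermediate CTC regularization (Lee & Watanabe, 2021) attaches an auxiliary CTC loss to an intermediate encoder layer and mixes it with the final-layer loss. A minimal PyTorch sketch, assuming batch-first (B, T, V) logits; the function and parameter names here are hypothetical, not this repository's API:

import torch
import torch.nn.functional as F

def intermediate_ctc_loss(final_logits, inter_logits, targets,
                          input_lengths, target_lengths,
                          blank=0, inter_weight=0.3):
    # final_logits / inter_logits: (B, T, V) outputs of the (shared) CTC
    # head applied to the last and an intermediate encoder layer.
    def ctc(logits):
        log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)  # (T, B, V)
        return F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                          blank=blank, reduction="mean", zero_infinity=True)
    # Convex mix of final-layer and intermediate-layer CTC losses.
    return (1.0 - inter_weight) * ctc(final_logits) + inter_weight * ctc(inter_logits)

# Toy usage: random tensors just to show the expected shapes.
B, T, V, U = 4, 50, 32, 10
final_logits, inter_logits = torch.randn(B, T, V), torch.randn(B, T, V)
targets = torch.randint(1, V, (B, U))
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), U, dtype=torch.long)
loss = intermediate_ctc_loss(final_logits, inter_logits, targets,
                             input_lengths, target_lengths)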

TeaPoly · Feb 16 '24 10:02