vector-quantize-pytorch

How to Train Residual LFQ

Open ZetangForward opened this issue 1 year ago • 1 comment

Hi, I want to train a Residual LFQ model for audio, and this is my core code:

import torch


def _loss_fn(loss_fn, x_target, x_pred, cfg, padding_mask=None):
    x_target = x_target.to(x_pred.device)

    if padding_mask is not None:
        # padding_mask is True at valid (non-padded) positions
        padding_mask = padding_mask.unsqueeze(-1).expand_as(x_target).to(x_pred.device)
        x_target = torch.where(padding_mask, x_target, torch.zeros_like(x_target))
        x_pred = torch.where(padding_mask, x_pred, torch.zeros_like(x_pred))
        mask_sum = padding_mask.sum()
    else:
        # no mask: average over every element
        mask_sum = x_target.numel()

    if loss_fn == 'l1':
        loss = torch.sum(torch.abs(x_pred - x_target)) / mask_sum
    elif loss_fn == 'l2':
        loss = torch.sum((x_pred - x_target) ** 2) / mask_sum
    elif loss_fn == 'linf':
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        if padding_mask is not None:
            # only keep the residual at valid (non-padded) positions
            residual = torch.where(padding_mask.reshape(x_target.shape[0], -1), residual, torch.zeros_like(residual))
        values, _ = torch.topk(residual, cfg.linf_k, dim=1)
        loss = torch.mean(values)
    else:
        raise ValueError(f"Unknown loss_fn {loss_fn}")

    return loss


def training_step(self, batch, batch_idx):
    quantized, indices, commit_loss = self.model(batch['audio'], batch['padding_mask'])
    quantized_out = self.model.get_output_from_indices(indices)
    # reconstruct the quantizer's input (the audio features) from the code indices
    reconstruction_loss = _loss_fn('l2', batch['audio'], quantized_out, self.cfg, batch['padding_mask'])
    return reconstruction_loss + commit_loss

model = ResidualLFQ(
    dim = config.lfq.dim,
    codebook_size = config.lfq.codebook_size,
    num_quantizers = config.lfq.num_quantizers,
)
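
Outside the training loop, I also ran a quick standalone sanity check to see what the module returns. The shapes and hyperparameters below are just examples, not my real config, and I am assuming get_output_from_indices reproduces the quantized output exactly:

import torch
from vector_quantize_pytorch import ResidualLFQ

# toy config, only for checking shapes and the sign of commit_loss
rlfq = ResidualLFQ(dim = 256, codebook_size = 1024, num_quantizers = 4)

x = torch.randn(2, 100, 256)                 # (batch, seq_len, dim)
quantized, indices, commit_loss = rlfq(x)    # same 3-tuple as in training_step above

# rebuilding the output purely from the indices should match the quantized output
from_indices = rlfq.get_output_from_indices(indices)

print(quantized.shape, indices.shape)
print('commit_loss:', commit_loss)           # this is where I sometimes see negative values during training
print(torch.allclose(quantized, from_indices, atol = 1e-5))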

I use reconstruction_loss and commit_loss to jointly update the ResidualLFQ model.

I am wondering about two things:

  1. Is the reconstruction loss necessary?
  2. Sometimes the commitment loss is negative (e.g., -0.02); is this normal? Since I add commit_loss and reconstruction_loss together, it seems odd that one loss is positive and the other is negative (see the sketch below for what I think is going on).
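
For question 2, my guess is that the auxiliary loss returned by LFQ is not a plain commitment term but also includes an entropy penalty in which the codebook-usage entropy is subtracted, so the total can dip below zero. Roughly something like the sketch below, which is only my own illustration of that kind of term; entropy_style_aux_loss and diversity_gamma are names I made up, not the library's API:

import torch

def entropy_style_aux_loss(logits, diversity_gamma = 1.0):
    # per-sample entropy: pushed down so each token commits to one code
    probs = logits.softmax(dim = -1)                                   # (batch, codebook_size)
    per_sample_entropy = -(probs * probs.clamp_min(1e-9).log()).sum(-1).mean()
    # codebook entropy over the batch: pushed up so many codes get used,
    # hence it enters with a minus sign and can make the total negative
    avg_probs = probs.mean(dim = 0)
    codebook_entropy = -(avg_probs * avg_probs.clamp_min(1e-9).log()).sum()
    return per_sample_entropy - diversity_gamma * codebook_entropy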

I hope to get some suggestions. @kashif @lucidrains, thank you!

ZetangForward avatar Jan 19 '24 04:01 ZetangForward

Same problem here, my commitment loss is also negative. Have you solved it?

VJJJJJJ1 avatar Feb 24 '25 12:02 VJJJJJJ1