vector-quantize-pytorch
How to Train Residual LFQ
Hi, I want to train a Residual LFQ model for audio, and this is my core code:
```python
import torch
from vector_quantize_pytorch import ResidualLFQ


def _loss_fn(loss_fn, x_target, x_pred, cfg, padding_mask=None):
    if padding_mask is not None:
        # broadcast the (batch, time) mask over the feature dimension
        padding_mask = padding_mask.unsqueeze(-1).expand_as(x_target)
        x_target = torch.where(padding_mask, x_target, torch.zeros_like(x_target)).to(x_pred.device)
        x_pred = torch.where(padding_mask, x_pred, torch.zeros_like(x_pred)).to(x_pred.device)
        mask_sum = padding_mask.sum()
    else:
        mask_sum = x_target.numel()

    if loss_fn == 'l1':
        loss = torch.sum(torch.abs(x_pred - x_target)) / mask_sum
    elif loss_fn == 'l2':
        loss = torch.sum((x_pred - x_target) ** 2) / mask_sum
    elif loss_fn == 'linf':
        residual = ((x_pred - x_target) ** 2).reshape(x_target.shape[0], -1)
        if padding_mask is not None:
            # only keep residuals at valid (non-padded) positions
            residual = torch.where(padding_mask.reshape(x_target.shape[0], -1), residual, torch.zeros_like(residual))
        values, _ = torch.topk(residual, cfg.linf_k, dim=1)
        loss = torch.mean(values)
    else:
        assert False, f"Unknown loss_fn {loss_fn}"
    return loss


def training_step(self, batch, batch_idx):
    quantized, indices, commit_loss = self.model(batch['audio'], batch['padding_mask'])
    quantized_out = self.model.get_output_from_indices(indices)
    reconstruction_loss = _loss_fn('l2', batch['svg_path'], quantized_out, self.cfg, batch['padding_mask'])
    return reconstruction_loss + commit_loss


model = ResidualLFQ(
    dim = config.lfq.dim,
    codebook_size = config.lfq.codebook_size,
    num_quantizers = config.lfq.num_quantizers
)
```
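As a sanity check, here is a minimal sketch of how I call `_loss_fn` on dummy tensors; the shapes and the mask convention (`True` = valid frame, `False` = padding) are assumptions for illustration, and `Cfg` is just a stand-in for my real config object:

```python
import torch

# dummy batch: 2 sequences, 6 frames, 4 features per frame
x_target = torch.randn(2, 6, 4)
x_pred = torch.randn(2, 6, 4)

# boolean mask over (batch, time): True = valid frame, False = padding
padding_mask = torch.tensor([
    [True, True, True, True, False, False],
    [True, True, True, True, True,  True],
])

class Cfg:        # hypothetical stand-in for the real config
    linf_k = 3    # only read by the 'linf' branch

loss = _loss_fn('l2', x_target, x_pred, Cfg(), padding_mask)
print(loss)  # mean squared error over the valid (unmasked) positions only
```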
I use `reconstruction_loss` and `commit_loss` to jointly update the ResidualLFQ model.
I am wondering about two things:
- Is the reconstruction loss necessary?
- Sometimes the commitment loss is negative, e.g. -0.02. Is this normal? Since I add `commit_loss` and `reconstruction_loss` together, it seems odd that one loss is positive and the other is negative; a small sketch of how I inspect the two terms is below.
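To make the two terms easier to compare, this is a sketch of the training step (assuming the PyTorch Lightning setup above) where I log both losses separately and put a weight on the commitment term; the `commit_weight` value and the `.sum()` over quantizer levels are my own assumptions, not something prescribed by the library:

```python
def training_step(self, batch, batch_idx):
    quantized, indices, commit_loss = self.model(batch['audio'], batch['padding_mask'])
    quantized_out = self.model.get_output_from_indices(indices)
    reconstruction_loss = _loss_fn('l2', batch['svg_path'], quantized_out, self.cfg, batch['padding_mask'])

    # reduce to a scalar in case one commitment loss is returned per residual quantizer
    commit_loss = commit_loss.sum()

    # log the terms separately, so a negative commitment loss is visible on its own
    self.log('train/reconstruction_loss', reconstruction_loss)
    self.log('train/commit_loss', commit_loss)

    # hypothetical weight on the auxiliary term (not a library default)
    commit_weight = 0.25
    return reconstruction_loss + commit_weight * commit_loss
```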
I hope to get some suggestions @kashif @lucidrains Thank you
Same problem here, my commitment loss is also negative. Have you solved it?