
FSQ Oddness

Open zaptrem opened this issue 9 months ago • 5 comments

VQPytorch's FSQ with symmetry enabled and noise dropout set to 0.5 seems to perform significantly better than the reference implementation in reconstruction loss with the same settings, so I set out to figure out why, suspecting one of the two implementations may be broken.

First, the symmetry-preserving bound:

    def symmetry_preserving_bound(self, z):
        """
        QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1
        """
        levels_minus_1 = (self._levels - 1)
        scale = 2.0 / levels_minus_1
        bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
        return scale * bracket - 1.0

The (L - 1) factors cancel, so this simplifies to a tanh plus a constant shift:

    def symmetry_preserving_bound(self, z):
        return torch.tanh(z) + 1.0 / (self._levels - 1)

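For what it's worth, a quick standalone check (my own sketch, not code from either library; levels stands in for self._levels along one dimension) confirms the algebra:

    import torch

    levels = 8  # stand-in for self._levels along one dimension
    z = torch.randn(1000)

    levels_minus_1 = levels - 1
    original = 2.0 / levels_minus_1 * ((levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5) - 1.0
    simplified = torch.tanh(z) + 1.0 / levels_minus_1

    # the (levels - 1) factors cancel, leaving tanh(z) + 1 / (levels - 1)
    assert torch.allclose(original, simplified, atol=1e-6)
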
Second, this version doesn't seem to match what they do in the reference implementation.

You do:

    symmetry_preserved_bound = torch.tanh(z) + 1.0 / (self._levels - 1)
    rounded = round_ste(symmetry_preserved_bound) / (self._levels // 2)

They do (also simplified):

    dfsq_scale_shift = (torch.tanh(z) / self.scale + 1) * (self._levels - 1) / 2
    rounded = round_ste(dfsq_scale_shift)
    dfsq_inverse_scale_shift = (rounded * self.scale * 2 / (self._levels - 1)) - self.scale

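To make the difference concrete, here is a small sketch of mine (not code from either repo) that prints the set of output values each path can produce, assuming scale = 1.0, levels = 5, and plain torch.round standing in for round_ste:

    import torch

    levels = 5   # hypothetical level count for one dimension
    scale = 1.0  # assumed value of self.scale

    z = torch.linspace(-4, 4, 10001)

    # VQPytorch path, as quoted above
    bound = torch.tanh(z) + 1.0 / (levels - 1)
    vq_codes = torch.round(bound) / (levels // 2)

    # reference path, as quoted above
    shifted = (torch.tanh(z) / scale + 1) * (levels - 1) / 2
    ref_codes = torch.round(shifted) * scale * 2 / (levels - 1) - scale

    print(vq_codes.unique())   # -> {-0.5, 0.0, 0.5}
    print(ref_codes.unique())  # -> {-1.0, -0.5, 0.0, 0.5, 1.0}

At least as quoted, the two paths land on different codebooks: three output values versus the five evenly spaced levels you'd expect.
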
Third, the noise scaling is slightly different; I'm not sure how much this matters:

Yours:

    offset = (torch.rand_like(z) - 0.5) / (self._levels // 2)
    quantized = torch.where(offset_mask, unquantized + offset, quantized)

Theirs:

    offset = (torch.rand_like(z) - 0.5) * (self.scale * 2 / (self._levels - 1))
    quantized = torch.where(mask, quantized, z + offset)

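How much it matters depends on the settings; a quick sketch (again mine, assuming scale = 1.0) compares the half-width of the uniform noise in each implementation:

    scale = 1.0  # assumed value of self.scale

    for levels in (3, 4, 5, 8):
        vq_half_width = 0.5 / (levels // 2)
        ref_half_width = 0.5 * (scale * 2 / (levels - 1))
        print(levels, vq_half_width, ref_half_width)

    # with scale = 1 the two agree for odd level counts (0.25 vs 0.25 at levels = 5)
    # but diverge for even ones (0.25 vs ~0.333 at levels = 4)
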
Fourth, you pass the non-tanh'ed input through to the noise-dropout quantization branch, which could allow the network to scale its activations arbitrarily to fight off the noise.

    offset_mask = torch.bernoulli(
        torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
    ).bool().expand_as(z)

    offset = (torch.rand_like(z) - 0.5) / half_width
    quantized = torch.where(offset_mask, unquantized + offset, quantized)

I suspect the last one is the cause of the performance difference and will check later.
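
If that turns out to be it, one possible fix (a sketch of my own, not anything from the library) would be to bound z before adding the noise, so the network can't inflate its activations to drown it out:

    # hypothetical fix: add noise to the bounded (tanh-squashed) value
    # instead of the raw input, so scaling z up no longer shrinks the
    # noise relative to the signal
    bounded = self.symmetry_preserving_bound(z)
    offset = (torch.rand_like(z) - 0.5) / half_width
    quantized = torch.where(offset_mask, bounded + offset, quantized)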

zaptrem, Feb 06 '25 22:02