CTranslate2
Strange outputs for Fairseq model with quant noise layer
I have trained a fairseq model with quant-noise (product quantization) using the parameters:
--quant-noise-pq 0.1 --quant-noise-pq-block-size 8
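For context, a module listing like the one below can be produced by iterating over `named_modules()` on the loaded model. A minimal, self-contained sketch on a toy model (the `Sequential` here is an illustrative stand-in, not the actual fairseq checkpoint):

```python
import torch.nn as nn

# Toy stand-in for the checkpoint's model; the real code would load the
# trained TransformerModel from the fairseq checkpoint instead.
model = nn.Sequential(
    nn.Embedding(10, 4),
    nn.Linear(4, 4),  # plays the role of the extra quant_noise Linear
)

# Walk the module tree the same way the listing in this issue was made.
for name, module in model.named_modules():
    print(f"Module Name: {name}, Module Type: {type(module)}")
    for pname, param in module.named_parameters(recurse=False):
        print(f"  Parameter Name: {pname}, Data Type: {param.dtype}, "
              f"Shape: {param.shape}")
```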
Listing modules from the checkpoint:
encoder:
Module Type: <class 'fairseq.models.transformer.transformer_legacy.TransformerModel'>
Module Name: encoder, Module Type: <class 'fairseq.models.transformer.transformer_encoder.TransformerEncoderBase'>
Module Name: encoder.dropout_module, Module Type: <class 'fairseq.modules.fairseq_dropout.FairseqDropout'>
Module Name: encoder.embed_tokens, Module Type: <class 'torch.nn.modules.sparse.Embedding'>
Module Name: encoder.embed_positions, Module Type: <class 'fairseq.modules.sinusoidal_positional_embedding.SinusoidalPositionalEmbedding'>
Module Name: encoder.quant_noise, Module Type: <class 'torch.nn.modules.linear.Linear'>
Parameter Name: weight, Data Type: torch.float32, Shape: torch.Size([512, 512])
tensor([[-0.1855, 0.2285, 0.2178, ..., -0.1128, 0.1709, -0.1030],
[-1.2734, 0.5352, 0.2715, ..., 0.0703, -0.3047, -0.2559],
[ 0.6797, 0.3164, -0.5625, ..., 0.0889, 0.1973, 0.0417],
...,
[-0.0952, 0.4824, 0.7539, ..., -0.5391, -0.1084, -0.1104],
[ 0.6875, -0.2695, 0.2910, ..., 0.6133, -0.1602, 0.1504],
[-0.3027, 1.6953, 0.0197, ..., -0.0488, -0.0125, 0.0181]])
Module Name: encoder.layers, Module Type: <class 'torch.nn.modules.container.ModuleList'>
Module Name: encoder.layers.0, Module Type: <class 'fairseq.modules.transformer_layer.TransformerEncoderLayerBase'>
Module Name: encoder.layers.0.self_attn, Module Type: <class 'fairseq.modules.multihead_attention.MultiheadAttention'>
Module Name: encoder.layers.0.self_attn.dropout_module, Module Type: <class 'fairseq.modules.fairseq_dropout.FairseqDropout'>
Module Name: encoder.layers.0.self_attn.k_proj, Module Type: <class 'torch.nn.modules.linear.Linear'>
Module Name: encoder.layers.0.self_attn.v_proj, Module Type: <class 'torch.nn.modules.linear.Linear'>
....
decoder:
Module Name: decoder.dropout_module, Module Type: <class 'fairseq.modules.fairseq_dropout.FairseqDropout'>
Module Name: decoder.embed_tokens, Module Type: <class 'torch.nn.modules.sparse.Embedding'>
Module Name: decoder.quant_noise, Module Type: <class 'torch.nn.modules.linear.Linear'>
Parameter Name: weight, Data Type: torch.float32, Shape: torch.Size([512, 512])
tensor([[ 1.3438, 0.4590, 0.7266, ..., -0.4648, 0.6719, 0.2988],
[ 0.0347, 0.0053, -0.0280, ..., -0.0107, -0.0116, -0.0019],
[-0.0107, 0.0075, 0.0327, ..., 0.0315, -0.0038, 0.0130],
...,
[ 0.7852, -0.6445, -0.0913, ..., 1.1250, 0.3965, 0.5625],
[ 0.4746, -0.5508, 0.1494, ..., 0.5273, -0.7344, 0.1074],
[-0.5078, -0.5078, -0.9727, ..., 0.8398, 1.1094, 0.4512]])
Module Name: decoder.embed_positions, Module Type: <class 'fairseq.modules.sinusoidal_positional_embedding.SinusoidalPositionalEmbedding'>
Module Name: decoder.layers, Module Type: <class 'torch.nn.modules.container.ModuleList'>
Module Name: decoder.layers.0, Module Type: <class 'fairseq.modules.transformer_layer.TransformerDecoderLayerBase'>
Module Name: decoder.layers.0.dropout_module, Module Type: <class 'fairseq.modules.fairseq_dropout.FairseqDropout'>
This is how the quant_noise layers interact (code from fairseq's transformer_encoder.py and transformer_decoder.py):
encoder:
    def forward_embedding(
        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
    ):
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed
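The key point in the encoder snippet: with --quant-noise-pq enabled, encoder.quant_noise is a full nn.Linear(embed_dim, embed_dim) applied to the embeddings, so a converter that silently drops this module feeds every encoder layer different activations. A minimal sketch of why skipping it cannot be a no-op (a freshly initialized Linear as a stand-in for the trained 512x512 weight shown above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim = 512

# Stand-in for the encoder.quant_noise module from the listing:
# a 512x512 Linear inserted after the embedding pipeline.
quant_noise = nn.Linear(embed_dim, embed_dim)

x = torch.randn(2, 7, embed_dim)  # (batch, seq, embed_dim) embeddings
y = quant_noise(x)

# The projection preserves shape but substantially changes the values,
# so a converter that omits it produces different layer inputs.
print(y.shape, (x - y).abs().mean())
```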
decoder:
    if self.quant_noise is not None:
        x = self.quant_noise(x)
    if self.project_in_dim is not None:
        x = self.project_in_dim(x)
    if positions is not None:
        x += positions
    if self.layernorm_embedding is not None:
        x = self.layernorm_embedding(x)
    x = self.dropout_module(x)
    # B x T x C -> T x B x C
    x = x.transpose(0, 1)
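Note the ordering difference between the two snippets: the encoder applies quant_noise after the positional embeddings (and dropout), while the decoder applies it before positions are added. A conversion that reorders or fuses this projection changes the result; a sketch contrasting the two orders with toy modules (not the trained weights):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
quant_noise = nn.Linear(dim, dim)
positions = torch.randn(1, 5, dim)  # stand-in positional embeddings
x = torch.randn(1, 5, dim)          # stand-in token embeddings

# Encoder-style: positions added first, then the quant_noise projection.
enc_out = quant_noise(x + positions)

# Decoder-style: quant_noise projection first, then positions added.
dec_out = quant_noise(x) + positions

# W(x + p) + b differs from (Wx + b) + p unless Wp happens to equal p,
# so the two placements are generically not interchangeable.
print((enc_out - dec_out).abs().max())
```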
The same checkpoint gives correct translation outputs when used with fairseq-generate.
What are the possible fixes for this issue?