Different sized encoder for TransformerDecoder

Open · Adamits opened this issue 10 months ago • 4 comments

It would be convenient to allow the encoder output_size to differ from the TransformerDecoder embedding size. To illustrate the problem, the code snippet below

import torch
import math


def generate_square_subsequent_mask(length: int) -> torch.Tensor:
    return torch.triu(torch.full((length, length), -math.inf), diagonal=1)


# INITIALIZE A TRANSFORMER WITH THIS HIDDEN AND EMBEDDING SIZE
hid = 128
emb = 64
decoder_layer = torch.nn.TransformerDecoderLayer(
    d_model=emb,
    dim_feedforward=hid,
    nhead=2,
    dropout=0.2,
    activation="relu",
    batch_first=True,
)
frank_transformer = torch.nn.TransformerDecoder(
    decoder_layer=decoder_layer,
    num_layers=2,
    norm=torch.nn.LayerNorm(emb),
)

# INITIALIZE TARGETS WITH EMBEDDING SIZE
# AND A FAKE ENCODER OUTPUT WITH HIDDEN SIZE
b = 4
seq_len = 10
target_embedding = torch.randn((b, seq_len, emb))
encoder_hidden = torch.randn(b, seq_len, hid)
target_sequence_length = target_embedding.size(1)
# -> seq_len x seq_len.
causal_mask = generate_square_subsequent_mask(seq_len)
# -> B x seq_len x d_model.
output = frank_transformer(
    target_embedding,
    encoder_hidden,
    tgt_mask=causal_mask,
    # memory_key_padding_mask=source_mask,
    # tgt_key_padding_mask=target_mask,
)

throws:

File "test.py", line 34, in <module>
    output = frank_transformer(
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 460, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/transformer.py", line 847, in forward
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
  File "torch/nn/modules/transformer.py", line 865, in _mha_block
    x = self.multihead_attn(x, mem, mem,
  File "torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "torch/nn/functional.py", line 5300, in multi_head_attention_forward
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
  File "torch/nn/functional.py", line 4836, in _in_projection_packed
    kv_proj = linear(k, w_kv, b_kv)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (40x128 and 64x128)

But if I change encoder_hidden = torch.randn(b, seq_len, hid) to encoder_hidden = torch.randn(b, seq_len, emb), then this works fine.

Essentially, we need the self-attention and the multihead (cross-)attention blocks to expect inputs of different sizes, which may also require changing the layer norms.
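
One possibly relevant detail: torch.nn.MultiheadAttention already accepts kdim and vdim arguments, so on its own a cross-attention module can take keys/values whose feature size differs from the queries. A minimal standalone sketch (just the bare attention module, not a change to TransformerDecoderLayer):

import torch

b, seq_len, emb, hid = 4, 10, 64, 128
# Cross-attention whose queries have size emb (decoder side) but whose
# keys/values come from an encoder with hidden size hid.
cross_attn = torch.nn.MultiheadAttention(
    embed_dim=emb,  # query size; also the output size
    num_heads=2,
    kdim=hid,
    vdim=hid,
    batch_first=True,
)
target_embedding = torch.randn(b, seq_len, emb)
encoder_hidden = torch.randn(b, seq_len, hid)
out, _ = cross_attn(target_embedding, encoder_hidden, encoder_hidden)
print(out.shape)  # torch.Size([4, 10, 64])

Since the attention output comes back at size emb, the residual connection and LayerNorm(emb) around the mha_block would, as far as I can tell, not need to change; the open question is how to wire this into TransformerDecoderLayer cleanly.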

I am putting this up and will try to work out a solution. The easiest way to allow this behavior in yoyodyne would be to project the encoder output into the decoder embedding size, or vice versa, but I feel that changes the architecture more than necessary. Instead, I would like to consider whether there is an elegant way to update the sa_block and mha_block such that it does not break other things in the transformer (e.g. layer norm).
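
For concreteness, the projection option mentioned above could look roughly like this, continuing the snippet at the top of the issue (bridge is a hypothetical extra module, not something that exists in yoyodyne today):

# Hypothetical linear bridge from the encoder output size (hid) down to
# the decoder embedding size (emb).
bridge = torch.nn.Linear(hid, emb)

output = frank_transformer(
    target_embedding,
    bridge(encoder_hidden),  # now B x seq_len x emb
    tgt_mask=causal_mask,
)

This runs, but it inserts an extra learned layer between the encoder and the decoder, which is exactly the kind of architectural change I would rather avoid.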

Adamits · Apr 24 '24 17:04