
Is it possible to add cross attention aka encoder outputs to SRU++ ?

Open hadaev8 opened this issue 3 years ago • 8 comments

hadaev8 avatar Feb 26 '21 13:02 hadaev8

Hi @hadaev8

At the moment, we haven't implemented an SRU++ "decoder" that has both self attention and cross attention. There are two options you could choose from:

  1. You can feed memory into SRU++ and it will treat it as extra context to attend to. In other words, you can do something like:
# encoder
enc_output, enc_hidden, _ = SRUpp_encoder(enc_input, pad_mask=is_padding_mask_enc)

# decoder
memory = [enc_output] * num_dec_layers  # a list of tensors, each of size (src_length, batch_size, d)
dec_output, dec_hidden, _ = SRUpp_decoder(dec_input,
                                          pad_mask=is_padding_mask_dec,
                                          attn_mask=attention_mask,   # (dec_length, src_length + dec_length)
                                          memory=memory,
                                          memory_mask_pad=...)

Note we are assuming all input and hidden dimensions are d here. A sketch of how the attention mask over the concatenated memory + decoder sequence could be built is shown after this list.

  2. You can customize an SRUpp decoding layer. See here: https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L907-L911
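
For what it's worth, here is a minimal sketch (not from the SRU repo) of how the (dec_length, src_length + dec_length) mask mentioned in option 1 could be built, assuming an additive mask where -inf blocks a position; double check the mask convention the SRU++ attention code actually expects:

import torch

def build_decoder_attn_mask(src_length: int, dec_length: int) -> torch.Tensor:
    # decoder positions may attend to every memory (encoder) position
    mem_part = torch.zeros(dec_length, src_length)
    # causal mask over the decoder part: position i attends to positions <= i only
    causal = torch.full((dec_length, dec_length), float("-inf")).triu(diagonal=1)
    return torch.cat([mem_part, causal], dim=1)   # (dec_length, src_length + dec_length)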

taoleicn avatar Feb 27 '21 17:02 taoleicn

@taoleicn The first option seems like the usual decoder in a transformer. It will attend to its own outputs and to the memory inputs, right?

Where can I find the transform_module definition?

hadaev8 avatar Feb 27 '21 18:02 hadaev8

@hadaev8 yes and no.

Yes, in the sense that within each SRU++ layer, the layer attends to both the self outputs and the memory inputs. No, in the sense that a transformer decoder has two attention sub-layers: one used only for self attention and another used only for cross attention. In option 1, the memory tensor is first concatenated with the self outputs from the previous layer, and then a single attention is applied. See https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L791-L793
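
To make the contrast concrete, here is a toy single-head sketch of the two behaviors (projections, masks, and multiple heads omitted; this is not the actual SRU++ code):

import torch

def toy_attention(q, k, v):
    # plain single-head dot-product attention over (length, batch, d) tensors
    scores = torch.einsum("qbd,kbd->bqk", q, k) / q.size(-1) ** 0.5
    return torch.einsum("bqk,kbd->qbd", scores.softmax(dim=-1), v)

x = torch.randn(5, 2, 16)        # decoder states   (dec_length, batch, d)
memory = torch.randn(7, 2, 16)   # encoder outputs  (src_length, batch, d)

# option 1: a single softmax spans the memory and the self positions together
kv = torch.cat([memory, x])
out_concat = toy_attention(x, kv, kv)

# transformer decoder: two separate softmaxes, one per attention sub-layer
out_self = toy_attention(x, x, x)
out_cross = toy_attention(out_self, memory, memory)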

Re: transform_module definition: https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L90-L94

How SRUpp sets transform_module as the attention sub-module: https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L1019-L1028 https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L1046

Forward method of SRUppCell: https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L907-L911
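
If you go the custom-layer route, one alternative to plugging into transform_module would be to wrap a single SRU++ layer with a separate cross-attention sub-layer, transformer-decoder style. Something like this untested sketch (the class and sizes are made up; double check the SRUpp constructor and forward signatures on the 3.0.0-dev branch):

import torch.nn as nn
from sru import SRUpp   # 3.0.0-dev branch

class SRUppDecoderLayer(nn.Module):
    # SRU++ handles self attention + the recurrence; a separate nn.MultiheadAttention
    # handles cross attention over the encoder outputs.
    def __init__(self, d_model, proj_size, num_heads=4):
        super().__init__()
        self.srupp = SRUpp(d_model, d_model, proj_size, num_layers=1)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output):
        # padding / causal masks omitted for brevity (see the SRUppCell forward above);
        # both x and enc_output are length-first: (length, batch, d_model)
        h, _, _ = self.srupp(x)
        c, _ = self.cross_attn(h, enc_output, enc_output)
        return self.norm(h + c)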

taolei87 avatar Feb 27 '21 18:02 taolei87

@taolei87 Do I understand correctly that this expects one memory vector instead of a sequence?

hadaev8 avatar Mar 01 '21 16:03 hadaev8

@hadaev8 I'm not sure I follow. Can you elaborate on your question?

taoleicn avatar Mar 01 '21 17:03 taoleicn

@taolei87 What is the expected size of the memory tensor?

hadaev8 avatar Mar 01 '21 18:03 hadaev8

It is a 3-dimensional tensor (memory_seq_len, batch_size, hidden_size). See an illustration below:

[figure: SRUppCell interface]

The SRUpp module takes a list of memory tensors (one for each sub-layer), and SRUppCell takes a single memory tensor. https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L1088 https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L771
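
A toy example with made-up sizes:

import torch

src_len, batch, d, num_dec_layers = 7, 2, 16, 4
enc_output = torch.randn(src_len, batch, d)   # (memory_seq_len, batch_size, hidden_size)

memory_list = [enc_output] * num_dec_layers   # what the SRUpp module takes (one tensor per layer)
single_memory = enc_output                    # what a single SRUppCell takes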

I updated the pseudocode in the previous reply with a correction.

taoleicn avatar Mar 01 '21 18:03 taoleicn

@taolei87 Now I get it. I will try it as is, but I have a feeling it's not a good idea to combine self and cross attention under one softmax. Any plans to add the more common form of cross attention? Also, how should it work at inference time?

Spotted this thing: https://github.com/asappresearch/sru/blob/3.0.0-dev/sru/modules.py#L158

hadaev8 avatar Mar 01 '21 20:03 hadaev8