pytorch-transformer
Inconsistency between model.decode() and the forward method of the Decoder class
In train.py, we do:
decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
encoder_output: (B, seq_len, d_model)
encoder_mask: (B, 1, 1, seq_len)
decoder_input: (B, seq_len)
decoder_mask: (B, 1, seq_len, seq_len)
This calls the decode method of the Transformer class in model.py, which takes its arguments in the following order: encoder_output, src_mask (i.e. encoder_mask), decoder_input, tgt_mask (i.e. decoder_mask).
Up to this point I could follow. But after this, the forward method of the Decoder class is called, which takes its inputs in a different order: x (the decoder input), encoder_output, encoder_mask, decoder_mask.
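To make the two orderings concrete, here is a minimal, self-contained sketch of the call chain as I understand it (plain-Python stand-ins instead of tensors; the simplified signatures are assumed from the description above, not copied from model.py). It shows that decode() re-orders its arguments positionally when it calls the decoder:

```python
# Hypothetical, simplified stand-ins for the classes in model.py.
class Decoder:
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        # x is the decoder input; return a dict so we can see
        # which argument landed in which parameter.
        return {"x": x, "enc": encoder_output,
                "src_mask": src_mask, "tgt_mask": tgt_mask}

class Transformer:
    def __init__(self):
        self.decoder = Decoder()

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        # Note the re-ordering: tgt moves to the first position
        # before being passed on to Decoder.forward.
        return self.decoder.forward(tgt, encoder_output, src_mask, tgt_mask)

# Mirrors the call in train.py:
model = Transformer()
out = model.decode("ENC_OUT", "ENC_MASK", "DEC_IN", "DEC_MASK")
```

In this sketch, out shows each value arriving at the intended parameter of Decoder.forward (x gets "DEC_IN", encoder_output gets "ENC_OUT", and so on), even though the two signatures list their parameters in different orders.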
Is this an issue?