
Inconsistency between model.decode() and the forward method of the Decoder class

Open SwastikGorai opened this issue 1 year ago • 0 comments

In train.py, we do:

decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)

where the shapes are:

- encoder_output: (B, seq_len, d_model)
- encoder_mask: (B, 1, 1, seq_len)
- decoder_input: (B, seq_len)
- decoder_mask: (B, 1, seq_len, seq_len)
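For context, masks with those shapes are typically built along these lines. This is only a rough sketch, not the repo's actual code; make_masks and pad_token_id are placeholder names:

```python
import torch

def make_masks(encoder_input, decoder_input, pad_token_id):
    # (B, seq_len) -> (B, 1, 1, seq_len): hide padding positions in the source
    encoder_mask = (encoder_input != pad_token_id).unsqueeze(1).unsqueeze(2).int()

    seq_len = decoder_input.size(1)
    # (1, seq_len, seq_len): lower-triangular causal mask so position i only
    # attends to positions <= i
    causal = torch.tril(torch.ones(1, seq_len, seq_len, dtype=torch.int))

    # (B, 1, seq_len, seq_len): target padding mask combined with the causal mask
    decoder_mask = (decoder_input != pad_token_id).unsqueeze(1).unsqueeze(2).int() & causal
    return encoder_mask, decoder_mask
```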

This calls the decode method of the Transformer class in model.py, which takes its arguments in the following order: encoder_output, src_mask (the encoder_mask), decoder_input, tgt_mask (the decoder_mask).
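For reference, here is a sketch of that method as I understand the call order; the internal attribute names (tgt_embed, tgt_pos, decoder) are assumptions and the rest of the class is omitted, so this is not verified against model.py:

```python
import torch.nn as nn

class Transformer(nn.Module):
    # sketch only: __init__ and the other methods are omitted
    def decode(self, encoder_output, encoder_mask, decoder_input, decoder_mask):
        x = self.tgt_embed(decoder_input)  # (B, seq_len) -> (B, seq_len, d_model)
        x = self.tgt_pos(x)                # add positional encodings
        # the Decoder is then invoked with the target sequence first
        return self.decoder(x, encoder_output, encoder_mask, decoder_mask)
```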

Up to this point I could follow. But after this, the forward method of the Decoder class is called, which takes its inputs in the following order: x (the decoder_input), encoder_output, encoder_mask, decoder_mask.
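And a corresponding sketch of the Decoder's forward method with the argument order described above; again, self.layers and self.norm are assumed names, not the verified source:

```python
import torch.nn as nn

class Decoder(nn.Module):
    # sketch only: __init__ is omitted; self.layers is assumed to be an
    # nn.ModuleList of decoder blocks and self.norm a final layer norm
    def forward(self, x, encoder_output, encoder_mask, decoder_mask):
        # x is the embedded decoder input; each block runs self-attention
        # (masked by decoder_mask), cross-attention against encoder_output
        # (masked by encoder_mask), and a feed-forward layer
        for layer in self.layers:
            x = layer(x, encoder_output, encoder_mask, decoder_mask)
        return self.norm(x)
```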

Is this an issue?
