
No need for a generator in the EncoderDecoder class

Open mkserge opened this issue 2 years ago • 2 comments

Hi,

Great notebook! Just wanted to mention that there is no need to pass the generator to the constructor of the EncoderDecoder class. It is a bit confusing: looking at the model description in the make_model method, one assumes the generator is part of the model, yet loss_compute applies the generator again.

Only after digging into the EncoderDecoder definition do you realize that the generator is never actually called in the model's forward pass, so the loss computation is in fact correct.
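To illustrate the point, here is a minimal, runnable sketch of the situation being described (the encoder/decoder stubs and dimensions are hypothetical stand-ins, not the notebook's real sublayers): the generator is stored on the module, but forward returns raw decoder hidden states without applying it.

```python
import torch
import torch.nn as nn


class PassThrough(nn.Module):
    """Hypothetical stand-in for the real encoder/decoder stacks."""
    def forward(self, x, *args):
        return x


class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator  # stored, but never called in forward

    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)


d_model, vocab = 8, 11  # toy sizes for the demo
model = EncoderDecoder(
    encoder=PassThrough(),
    decoder=PassThrough(),
    src_embed=nn.Embedding(vocab, d_model),
    tgt_embed=nn.Embedding(vocab, d_model),
    generator=nn.Linear(d_model, vocab),
)

src = torch.tensor([[1, 2, 3, 4]])
tgt = torch.tensor([[1, 2, 3]])
out = model(src, tgt, None, None)
print(out.shape)  # (1, 3, d_model): hidden states, not vocab logits
```

Because forward returns a d_model-sized tensor rather than vocab-sized logits, the training loop must apply model.generator itself before computing the loss, which is exactly what loss_compute does.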

mkserge avatar Dec 31 '22 23:12 mkserge

Maybe the forward function in EncoderDecoder should be:

    def forward(self, src, tgt, src_mask, tgt_mask):
        memory = self.encode(src, src_mask)
        res_dec = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.generator(res_dec)
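A quick sanity check of what this change would mean (a hypothetical toy model, not the notebook's code): once forward applies the generator itself, the output's last dimension is the vocabulary size, so the training loop must not apply the generator a second time.

```python
import torch
import torch.nn as nn

d_model, vocab = 8, 11  # toy sizes for the demo


class TinyModel(nn.Module):
    """Hypothetical sketch: projection happens inside forward."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        self.generator = nn.Linear(d_model, vocab)

    def decode_stub(self, tgt):
        # stand-in for encode/decode; returns d_model-sized states
        return self.embed(tgt)

    def forward(self, tgt):
        # generator applied inside forward, as proposed above
        return self.generator(self.decode_stub(tgt))


model = TinyModel()
out = model(torch.tensor([[1, 2, 3]]))
print(out.shape)  # (1, 3, vocab): vocab-sized logits straight from forward
```

With this version, loss_compute would receive logits directly and should drop its own call to the generator.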

zh-jp avatar Jan 17 '24 09:01 zh-jp