annotated-transformer
No need for a generator in the EncoderDecoder class
Hi,
Great notebook! Just wanted to mention that there is no need to pass the generator to the constructor of the EncoderDecoder class. It makes things a bit confusing: looking at the model description in the make_model function, one assumes that the generator is part of the model, yet loss_compute applies the generator again. Only after digging into the EncoderDecoder definition do you realize that the generator is not actually used in the forward pass, so the loss computation is in fact correct.
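For reference, this is roughly the pattern in question (a paraphrased sketch of the notebook's structure, not its exact code): the generator is stored on the model, but forward never calls it, and the loss computation applies it separately.

import torch.nn as nn

class EncoderDecoder(nn.Module):
    "A standard Encoder-Decoder architecture (sketch)."
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator  # stored here, but never used below

    def forward(self, src, tgt, src_mask, tgt_mask):
        # Returns raw decoder states; self.generator is NOT applied here.
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

class SimpleLossCompute:
    "Loss computation; note the generator is applied here instead."
    def __init__(self, generator, criterion):
        self.generator = generator
        self.criterion = criterion

    def __call__(self, x, y, norm):
        x = self.generator(x)  # projection to the vocabulary happens outside the model
        loss = self.criterion(
            x.contiguous().view(-1, x.size(-1)),
            y.contiguous().view(-1),
        ) / norm
        return loss.data * norm, loss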
Maybe the forward function in EncoderDecoder should be:
def forward(self, src, tgt, src_mask, tgt_mask):
    # Encode the source, decode the target, then apply the generator
    # so the model's output is the final vocabulary projection.
    memory = self.encode(src, src_mask)
    res_dec = self.decode(memory, src_mask, tgt, tgt_mask)
    return self.generator(res_dec)
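With this change, loss_compute would of course need to stop applying the generator itself, otherwise it would be applied twice. Inference code that calls encode/decode directly (e.g. the greedy decoding helper) should be unaffected, since it bypasses forward.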