MaskGIT-pytorch
sample_good() function in transformer.py
Hi!
I think the shape of the logits from self.tokens_to_logits is [batch, 257, 1026], because you defined self.tok_emb = nn.Embedding(args.num_codebook_vectors + 2, args.dim).
However, the codebook only has 1024 embeddings, so sampling an index from the two extra classes causes errors when decoding. Haven't you seen these errors during sampling? Did I miss something here?
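Here is a minimal sketch of the mismatch I mean (not the repo's actual sampling code; the shapes and the slicing workaround are my assumptions based on num_codebook_vectors = 1024):

```python
import torch

num_codebook_vectors = 1024              # size of the VQGAN codebook
vocab_size = num_codebook_vectors + 2    # codebook + mask token + sos token, per tok_emb

# logits as I understand tokens_to_logits to produce them: [batch, 257, 1026]
logits = torch.randn(1, 257, vocab_size)
probs = logits.softmax(dim=-1)
sampled = torch.multinomial(probs.view(-1, vocab_size), 1).view(1, 257)

# Any sampled index >= 1024 has no corresponding codebook entry, so looking it
# up in the VQGAN codebook would raise an index error.
print((sampled >= num_codebook_vectors).any())

# One possible workaround (my assumption, not necessarily the intended fix):
# restrict sampling to the first num_codebook_vectors logits so only valid
# codebook indices can be drawn.
probs_valid = logits[..., :num_codebook_vectors].softmax(dim=-1)
sampled_valid = torch.multinomial(
    probs_valid.view(-1, num_codebook_vectors), 1
).view(1, 257)
print((sampled_valid >= num_codebook_vectors).any())  # always False
```

Is something like this intended in sample_good(), or is the extra-token handling done elsewhere?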