TextGAN-PyTorch
sampling data from the generator
I could not understand how you are sampling data from the generator. Specifically, how do you create batches from num_samples?
```python
def sample(self, num_samples, batch_size, start_letter=cfg.start_letter):
    """
    Samples the network and returns num_samples samples of length max_seq_len.
    :return samples: num_samples * max_seq_len (a sampled sequence in each row)
    """
    # Round up so that num_batch * batch_size >= num_samples; the surplus
    # rows are discarded after generation.
    num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1
    samples = torch.zeros(num_batch * batch_size, self.max_seq_len).long()

    # Generate sentences batch by batch with a multinomial sampling strategy
    for b in range(num_batch):
        hidden = self.init_hidden(batch_size)
        # Every sequence in the batch starts from the same start token
        inp = torch.LongTensor([start_letter] * batch_size)
        if self.gpu:
            inp = inp.cuda()

        for i in range(self.max_seq_len):
            # forward returns log-probabilities over the vocabulary
            out, hidden = self.forward(inp, hidden, need_hidden=True)  # out: batch_size * vocab_size
            # exp() recovers probabilities; sample one token id per row
            next_token = torch.multinomial(torch.exp(out), 1)  # batch_size * 1
            samples[b * batch_size:(b + 1) * batch_size, i] = next_token.view(-1)
            inp = next_token.view(-1)  # feed the sampled tokens back as the next input

    # Drop the surplus rows from the last (partial) batch
    samples = samples[:num_samples]
    return samples
```
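For context, a minimal usage sketch (the generator instance name `gen` and the concrete numbers are assumptions for illustration, not from the repo):

```python
# Hypothetical call on a trained LSTMGenerator instance `gen`
samples = gen.sample(num_samples=10000, batch_size=64)
print(samples.shape)  # torch.Size([10000, gen.max_seq_len]), one token id per cell
```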
sample() draws num_samples sentences from the generator G. CUDA memory would overflow if we sampled all num_samples sentences from G at once, so the num_samples sentences are divided into num_batch batches.
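As a concrete illustration of that batching arithmetic (the numbers are arbitrary, chosen only for illustration):

```python
num_samples, batch_size = 10000, 64
# Same formula as in sample(): round up so all num_samples rows are covered
num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1
print(num_batch)               # 157
print(num_batch * batch_size)  # 10048 rows are generated in total
# samples = samples[:num_samples] then keeps only the first 10000 rows
```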
Does sampling from the generator mean the same thing as running inference with the pretrained MLE model? In RelGAN, are the models used for pretraining and for generating samples different?