TextGAN-PyTorch

sampling data from the generator

dishavarshney082 opened this issue on Feb 25, 2020 · 2 comments

I don't understand how you are sampling data from the generator. Specifically, how are batches created from num_samples?

def sample(self, num_samples, batch_size, start_letter=cfg.start_letter):
    """
    Samples the network and returns num_samples samples of length max_seq_len.
    :return samples: num_samples * max_seq_len (a sampled sequence in each row)
    """
    # Over-allocate a whole number of batches; the surplus rows are trimmed at the end
    num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1
    samples = torch.zeros(num_batch * batch_size, self.max_seq_len).long()

    # Generate sentences with a multinomial sampling strategy
    for b in range(num_batch):
        hidden = self.init_hidden(batch_size)
        inp = torch.LongTensor([start_letter] * batch_size)
        if self.gpu:
            inp = inp.cuda()

        for i in range(self.max_seq_len):
            out, hidden = self.forward(inp, hidden, need_hidden=True)  # out: batch_size * vocab_size (log-probabilities)
            # exp() turns the log-probabilities back into probabilities; draw one token id per row
            next_token = torch.multinomial(torch.exp(out), 1)  # batch_size * 1
            samples[b * batch_size:(b + 1) * batch_size, i] = next_token.view(-1)
            inp = next_token.view(-1)  # feed the sampled tokens back in as the next input
    samples = samples[:num_samples]  # drop the padding rows from the final, partially filled batch

    return samples
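For reference, here is the per-step sampling line in isolation (a minimal sketch with made-up batch_size and vocab_size; I'm assuming out holds log-probabilities, since it is exponentiated before being passed to torch.multinomial):

import torch

batch_size, vocab_size = 4, 6

# Toy log-probabilities, as a log_softmax over the vocabulary would produce
out = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)

# exp() recovers probabilities; multinomial draws one token id per row
next_token = torch.multinomial(torch.exp(out), 1)  # batch_size * 1
print(next_token.view(-1))  # one sampled token id per sequence in the batch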

dishavarshney082 · Feb 25, 2020

sample() draws num_samples sentences from the generator G. Sampling all num_samples sentences at once would overflow CUDA memory, so the num_samples sentences are split into num_batch smaller batches.
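As a concrete illustration of the batching arithmetic, here is a minimal, self-contained sketch (num_samples, batch_size, and max_seq_len are made-up example values, not from the config):

import torch

num_samples, batch_size, max_seq_len = 10000, 64, 20

# Same arithmetic as in sample(): over-allocate a whole number of batches...
num_batch = num_samples // batch_size + 1 if num_samples != batch_size else 1
samples = torch.zeros(num_batch * batch_size, max_seq_len).long()
print(num_batch, samples.shape)   # 157 torch.Size([10048, 20])

# ...fill samples batch by batch (omitted here), then trim the surplus rows
samples = samples[:num_samples]
print(samples.shape)              # torch.Size([10000, 20])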

williamSYSU · Feb 26, 2020

Does sampling from the generator mean the same thing as running inference with the pretrained MLE model? In RelGAN, are the models used for pretraining and for generating samples different?

dishavarshney082 · Feb 26, 2020