
Some slight changes to get it working

summerstay opened this issue 5 years ago • 3 comments

I found the PTB files at Mikolov's page.

I defined a tokenizer to work with my own text data:

    tokenizer = nltk.TreebankWordTokenizer()
    s = tokenizer.tokenize(sentence)
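As a concrete sketch of that tokenizer (the sample sentence below is made up for illustration, not from the PTB data):

```python
import nltk

# TreebankWordTokenizer follows Penn Treebank conventions:
# punctuation becomes separate tokens and contractions are split.
tokenizer = nltk.TreebankWordTokenizer()
s = tokenizer.tokenize("I found the PTB files, didn't I?")
print(s)
# → ['I', 'found', 'the', 'PTB', 'files', ',', 'did', "n't", 'I', '?']
```

Because this matches the tokenization PTB itself uses, sentences from other corpora end up in the same format the model expects.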

In loss_fn() I added a line to fix an error I was getting about the loss expecting a Long tensor:

    target = target.type(torch.LongTensor)
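A minimal sketch of why the cast is needed, assuming `nll` is a `torch.nn.NLLLoss` instance as the call in the snippet below suggests (the shapes here are made up for illustration):

```python
import torch
import torch.nn as nn

nll = nn.NLLLoss(reduction='sum')
logp = torch.log_softmax(torch.randn(4, 10), dim=-1)  # 4 tokens, vocab of 10

# A float-typed target raises "expected scalar type Long but found Float",
# because NLLLoss interprets the target as class indices (int64).
target = torch.tensor([1.0, 3.0, 5.0, 7.0])
target = target.type(torch.LongTensor)  # the fix: cast indices to int64

loss = nll(logp, target)
print(loss.item())
```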

summerstay avatar Apr 24 '20 01:04 summerstay

Thanks @summerstay! Can you send me a pull request or send me a snippet of the code with the fix clearly marked?

michalivne avatar Apr 27 '20 16:04 michalivne

def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0, z, split,
            pad_idx=datasets['train'].pad_idx):

    # cut off unnecessary padding from target, and flatten
    batch_size = logv.shape[0]
    target = target[:, :torch.max(length).item()].contiguous().view(-1)
    # ----------------- NEW LINE BELOW ---------------------------------
    target = target.type(torch.LongTensor)
    logp = logp.view(-1, logp.size(2))

    # Negative Log Likelihood
    NLL_loss = nll(logp, target)
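One caveat worth noting (my observation, not from the thread): `.type(torch.LongTensor)` always yields a CPU tensor, so on a GPU run the equivalent `.long()` is safer because it casts without changing the tensor's device:

```python
import torch

target = torch.tensor([1.0, 2.0, 3.0])

# .type(torch.LongTensor) casts to int64 AND moves the result to CPU;
# .long() casts to int64 on whatever device the tensor already lives on.
a = target.type(torch.LongTensor)
b = target.long()
print(a.dtype, b.dtype)  # torch.int64 torch.int64
```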

summerstay avatar Apr 27 '20 19:04 summerstay

I made a similar change in test.py:

for k, v in samples.items():
    # ------------------ LINE BELOW HAS CHANGED --------------------------
    samples[k] = torch.stack(v)[:args.num_samples].type(torch.LongTensor)

z, mean, std = model.encode(samples['input'], samples['length'], return_mean=True, return_std=True)
z = z.detach()
mean = mean.detach()
mean_recon, _ = model.inference(z=mean)
mean_recon = mean_recon.detach()
z_recon, _ = model.inference(z=z)
z_recon = z_recon.detach()
pert, _ = model.inference(z=z + torch.randn_like(z) * args.pert * std)
pert = pert.detach()
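The stacking step can be sketched in isolation (the list and `num_samples` value below are made up, standing in for one of the `samples` values and `args.num_samples`):

```python
import torch

v = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), torch.tensor([5.0, 6.0])]
num_samples = 2

# Stack the list of 1-D tensors into a (3, 2) tensor, keep the first
# num_samples rows, and cast to int64 so downstream lookups see indices.
out = torch.stack(v)[:num_samples].type(torch.LongTensor)
print(out.shape, out.dtype)  # torch.Size([2, 2]) torch.int64
```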

summerstay avatar Apr 27 '20 19:04 summerstay