SentenceMIM-demo
Some slight changes to get it working
I found the PTB files at Mikolov's page.
I defined a tokenizer to work with my own text data:

    tokenizer = nltk.TreebankWordTokenizer()
    s = tokenizer.tokenize(sentence)
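For anyone trying the same thing, here is a minimal sketch of that tokenizer step (the sentence is just a made-up example). `TreebankWordTokenizer` is regex-based, so it needs no `nltk.download()` calls:

```python
import nltk

# TreebankWordTokenizer splits text using Penn Treebank conventions:
# contractions are separated (can't -> ca + n't) and punctuation is
# split off as its own token.
tokenizer = nltk.TreebankWordTokenizer()
sentence = "I can't believe it works, finally!"
tokens = tokenizer.tokenize(sentence)
print(tokens)
# → ['I', 'ca', "n't", 'believe', 'it', 'works', ',', 'finally', '!']
```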
In loss_fn() I added a line to fix an error I was getting about expecting Long:

    target = target.type(torch.LongTensor)
Thanks @summerstay! Can you send me a pull request or send me a snippet of the code with the fix clearly marked?
    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0, z, split,
                pad_idx=datasets['train'].pad_idx):

        # cut off unnecessary padding from target, and flatten
        batch_size = logv.shape[0]
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        # ----------------- NEW LINE BELOW ---------------------------------
        target = target.type(torch.LongTensor)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = nll(logp, target)
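One caveat with that fix: `.type(torch.LongTensor)` also moves the tensor to CPU, which will break if the rest of the batch is on GPU. A device-preserving alternative is `.long()` (or `.to(torch.long)`), sketched here with a stand-in tensor:

```python
import torch

# Stand-in for the target tensor; in the real loss_fn it comes from the batch.
target = torch.tensor([[1.0, 2.0, 0.0]])

# .long() casts to int64 just like .type(torch.LongTensor), but keeps the
# tensor on whatever device it already lives on (CPU or GPU).
target = target.long()
print(target.dtype)  # → torch.int64
```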
I made a similar change in test.py:
    for k, v in samples.items():
        # ------------------ LINE BELOW HAS CHANGED --------------------------
        samples[k] = torch.stack(v)[:args.num_samples].type(torch.LongTensor)

    z, mean, std = model.encode(samples['input'], samples['length'], return_mean=True, return_std=True)
    z = z.detach()
    mean = mean.detach()
    mean_recon, _ = model.inference(z=mean)
    mean_recon = mean_recon.detach()
    z_recon, _ = model.inference(z=z)
    z_recon = z_recon.detach()
    pert, _ = model.inference(z=z + torch.randn_like(z) * args.pert * std)
    pert = pert.detach()
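For context, the last call decodes a perturbed latent: Gaussian noise is scaled per-dimension by the posterior std and by `args.pert` before being added to `z`. A minimal sketch of just that sampling step, with made-up shapes in place of the encoder outputs:

```python
import torch

torch.manual_seed(0)
z = torch.zeros(4, 8)            # stand-in for the encoded latents
std = torch.full((4, 8), 0.5)    # stand-in for the posterior std
pert = 0.1                       # corresponds to args.pert

# Noise is elementwise: each latent dimension gets noise proportional
# to its own posterior std, shrunk by the pert factor.
z_pert = z + torch.randn_like(z) * pert * std
print(z_pert.shape)  # → torch.Size([4, 8])
```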