How to retrieve the output word predicted by the LM?
Greetings,
I deleted my old question about the input and vocabulary, since I solved that on my own. The remaining doubt is: how do I retrieve the output word predicted by the LM?
I surmise that one has to extract the softmax(logits) output specified in model.transformer, then sample a vocabulary index from that distribution and convert it back to a word. Is that correct? Are there better ways?
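For reference, here is a minimal sketch of that idea against the pretrained transfo-xl-wt103 checkpoint from pytorch-pretrained-BERT. It assumes TransfoXLLMHeadModel returns per-token log-probabilities over the vocabulary when no target is passed; double-check this against the library version you are using.

import torch
from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
model.eval()

# Encode a context and run the model.
tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('the quick brown fox'))
input_ids = torch.tensor([tokens])
with torch.no_grad():
    log_probs, mems = model(input_ids)  # assumed shape: (batch, seq_len, vocab)

# Distribution over the next word = prediction at the last position.
probs = log_probs[0, -1].exp()

# Greedy choice ...
top_id = torch.argmax(probs).item()
# ... or sample a vocabulary index from the distribution.
sampled_id = torch.multinomial(probs, num_samples=1).item()

print(tokenizer.convert_ids_to_tokens([top_id, sampled_id]))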
For the sake of anybody else coming across this issue: I have since noticed that Issue 49 has a similar question. The answer there includes a script that uses the pretrained models from the pytorch-pretrained-BERT library.
You can use my code for the projected softmax:
if compute_full_outp:
    # Log-probs of the head cluster (the shortlist of most frequent words).
    out_full_logps = [head_logprob[:, :self.cutoffs[0]]]
    cutoff_values = [0] + self.cutoffs
    for i in range(1, len(cutoff_values) - 1):
        # Trivial mask: we want the full distribution at every position.
        mask_i = target >= 0
        indices_i = mask_i.nonzero().squeeze()
        if indices_i.numel() == 0:
            continue
        head_logprob_i = head_logprob.index_select(0, indices_i)
        weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
        hidden_i = hidden.index_select(0, indices_i)
        tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
        tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
        # log P(word) = log P(cluster i) + log P(word | cluster i); the cluster
        # log-prob is the i-th column from the end of the head output.
        out_full_logps.append(head_logprob_i[:, -i].view(-1, 1) + tail_logprob_i)
    # Concatenate head and tail clusters into one distribution per position.
    out_full_ps = torch.exp(torch.cat(out_full_logps, dim=1))
    return_d['out_full_ps'] = out_full_ps
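A quick, hypothetical usage note (the tensor layout is assumed from the library's adaptive softmax, where hidden is flattened to (seq_len * batch, d_model)): each row of return_d['out_full_ps'] is then a full probability distribution over the vocabulary, so the next word can be sampled from the last row.

# Hypothetical usage; `return_d` comes from the modified forward above and
# `tokenizer` from the sampling sketch earlier in the thread.
next_word_probs = return_d['out_full_ps'][-1]  # distribution at the last position
next_id = torch.multinomial(next_word_probs, num_samples=1).item()
print(tokenizer.convert_ids_to_tokens([next_id]))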