❓ Forced decoding & decoder score
I'm using the hub interface. It is possible to get the decoder score of a hypothesis generated by the model:
model = torch.hub.load(...)
sent_src_enc = model.encode(sent_src)
sent_tgt_enc = model.generate(sent_src_enc, nbest=1)[0]
sent_tgt_score = sent_tgt_enc["score"].item()
Assuming I already have the source and hypothesis text from elsewhere, how would I force the decoder to decode the target text and return the log-probability? I know of the existence of SequenceScorer and --score-reference but was unable to use them with the hub interface:
scorer = SequenceScorer(model.tgt_dict)
sent_src_enc = model.encode(sent_src)
sent_tgt_enc = model.encode(sent_tgt)
scorer.generate(model.models(), {"net_input": sent_src_enc, "target": sent_tgt_enc}) # ERROR
What's your environment?
- fairseq version: 0.12.2
- PyTorch version: 1.12.1+cu116
- OS: Ubuntu 22.04
- How you installed fairseq: pip
- Python version: 3.10.4
- CUDA/cuDNN version: 11.5
What is the error? Is it about the wrong structure of the input {"net_input": sent_src_enc, "target": sent_tgt_enc}?
If it is a Transformer with a normal translation task, it uses LanguagePairDataset.
You can see the expected batch structure in its definition file, fairseq/data/language_pair_dataset.py.
Its collater creates prev_output_tokens from target.
In addition, the Transformer's forward shows that it does not look up target but prev_output_tokens, so I guess you should change the key target to prev_output_tokens.
One last thing: prev_output_tokens is a plain tensor, not a dictionary (see the decoder's forward).
You had better check what model.encode(sent_tgt) returns and take what you need from it; the merge helper in that dataset's collater may be useful, as in the sketch below.
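For illustration, here is a minimal sketch of the shift that LanguagePairDataset's collater performs with move_eos_to_beginning (assuming ref is a 1-D encoded target tensor ending in eos, e.g. from model.encode(sent_tgt)):

import torch

# The collater builds prev_output_tokens (the teacher-forcing input) by moving
# the final eos of the target to the front; target itself keeps eos at the end.
def shift_eos_to_beginning(ref: torch.LongTensor) -> torch.LongTensor:
    return torch.cat([ref[-1:], ref[:-1]])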
If your model is not a fairseq Transformer, or you are using a different dataset, the basic flow is the same: find what data is expected and build a proper batch yourself. Good luck!
You can find a somewhat related example in the eval_lm function in fairseq_cli/eval_lm.py. You may call that function directly if cross-entropy loss is what you ultimately want.
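If per-token log-probabilities are all you need in the end, here is a hedged sketch of essentially what SequenceScorer computes internally; model is the hub interface from above, and src_tokens, src_lengths, prev_output_tokens, and target are assumed to be batched tensors as in a regular fairseq sample:

import torch

with torch.no_grad():
    net = model.models[0]  # a single underlying encoder-decoder model
    decoder_out = net(src_tokens, src_lengths, prev_output_tokens)
    # log_probs=True yields natural-log probabilities over the vocabulary
    lprobs = net.get_normalized_probs(decoder_out, log_probs=True)
    # log-prob of each reference token: shape [batch, tgt_len]
    token_scores = lprobs.gather(2, target.unsqueeze(-1)).squeeze(-1)
    sentence_logprob = token_scores.sum(dim=-1)  # log_e P(target | source)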
You can perform forced decoding with the following script:
#!/usr/bin/env python3
import torch
from fairseq.models.transformer import TransformerModel
from fairseq.sequence_scorer import SequenceScorer

if __name__ == "__main__":
    sent_src = "Hello world!"
    sent_tgt = "Hallo Welt!"

    model = TransformerModel.from_pretrained(...)
    model.eval()
    scorer = SequenceScorer(model.tgt_dict)

    src_tokens = model.encode(sent_src)  # 1-D LongTensor ending in eos
    ref_tokens = model.encode(sent_tgt)  # 1-D LongTensor ending in eos

    # Build prev_output_tokens the way LanguagePairDataset's collater does:
    # move the final eos of the target to the beginning (teacher forcing).
    prev_output_tokens = torch.cat([ref_tokens[-1:], ref_tokens[:-1]]).unsqueeze(0)

    sample = {
        "net_input": {
            "src_tokens": src_tokens.unsqueeze(0),
            "src_lengths": torch.LongTensor([src_tokens.numel()]),
            "prev_output_tokens": prev_output_tokens,
        },
        "target": ref_tokens.unsqueeze(0),
    }

    hypos = scorer.generate(model.models, sample)
    # "score" is the average per-token log_e probability of the reference
    print(hypos[0][0]["score"])
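Besides "score" (the average per-token log probability), each hypothesis dict returned by SequenceScorer also contains positional_scores with the log-probability of every individual reference token, in case you need token-level values rather than the sentence average.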