How do I truncate the source when using `fairseq-interactive`?
❓ Questions and Help
Before asking:
- search the issues.
- search the docs.
What is your question?
How do I truncate the source to the maximum source positions length when translating interactively?
Code
What have you tried?
Adding the `--truncate-source` flag does nothing, and adding `--skip-invalid-size-inputs-valid-test` doesn't do what I want either: it skips the oversized example entirely instead of translating it.
What's your environment?
- fairseq Version (e.g., 1.0 or main): main
- PyTorch Version (e.g., 1.0): 1.11
- OS (e.g., Linux): Linux
- How you installed fairseq (pip, source): source
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
I searched "truncate_source" in the fairseq repository and it exists in only def load_dataset for TranslationTask.
This is used by fairseq-train and generate,
While fairseq-interactive.py prepares datasets by dataset=task.build_dataset_for_inference(tokens, lengths, constraints=constraints_tensor) ,
this means you either edit this method to:
```python
# I have not tested this code. If anything is wrong, please tell me.
# ListDataset must be added to the `from fairseq.data import ...` line at the
# top of translation.py; the other wrappers are already imported there.
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    # !NOTE: src_lengths is stale if truncation kicks in and is not fixed here,
    # though that should have no impact during inference; lengths feed the loss.
    print("length before truncation:", src_tokens[0].size())
    _eos = self.source_dictionary.eos()
    src_dataset = AppendTokenDataset(
        TruncateDataset(
            # ListDataset gives the raw token list the `sizes` attribute the
            # wrapper datasets expect; StripTokenDataset removes the trailing
            # EOS so truncation cannot cut it off, and it is re-appended below.
            StripTokenDataset(ListDataset(src_tokens, src_lengths), _eos),
            self.cfg.max_source_positions - 1,
        ),
        _eos,
    )
    ans = LanguagePairDataset(
        src_dataset,
        src_lengths,
        self.source_dictionary,
        tgt_dict=self.target_dictionary,
        constraints=constraints,
    )
    print("length after truncation:", ans[0]["source"].size())
    return ans
```
or you can copy translation.py and register a new task:

```python
@register_task("{your new task name}", dataclass={that task's config class; if there are no new arguments you can reuse TranslationConfig})
class YourNewTask(TranslationTask):
    # only override/add the methods you need
```

and edit the copy, so you do not corrupt your local fairseq install. Make the new task class importable by passing `--user-dir`.
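A minimal, untested sketch of that route; the directory `my_user_dir` and the task name `truncating_translation` are made-up names for illustration:

```python
# my_user_dir/__init__.py -- untested sketch; the layout and task name are
# illustrative assumptions, not fairseq conventions.
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationConfig, TranslationTask


# No new command-line arguments are needed, so TranslationConfig is reused.
@register_task("truncating_translation", dataclass=TranslationConfig)
class TruncatingTranslationTask(TranslationTask):
    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        # same truncating body as the patched method above
        ...
```

Run with `fairseq-interactive ... --user-dir my_user_dir --task truncating_translation` and your installed fairseq stays untouched.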
--
Another option is to copy/edit fairseq/data/dictionary.py. It has an `encode_line` method, and conveniently there is an `i` tracking the token count within the sentence: clamp `nwords` to your maximum length and `break` once `i` is big enough. The downside is that the dictionary takes no command-line arguments, so you have to hard-code the length.
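A rough, untested sketch of that edit; `MAX_SOURCE_LEN` is a made-up constant that has to be hard-coded for exactly the reason above:

```python
# In fairseq/data/dictionary.py -- untested sketch of a truncating encode_line.
# MAX_SOURCE_LEN is a made-up, hard-coded constant (no CLI plumbing exists here).
MAX_SOURCE_LEN = 1022  # e.g. max_source_positions minus room for EOS


def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
                consumer=None, append_eos=True, reverse_order=False):
    words = line_tokenizer(line)
    if reverse_order:
        words = list(reversed(words))
    nwords = min(len(words), MAX_SOURCE_LEN)  # clamp instead of len(words)
    ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
    for i, word in enumerate(words):
        if i >= nwords:  # stop once the cap is reached
            break
        idx = self.add_symbol(word) if add_if_not_exist else self.index(word)
        if consumer is not None:
            consumer(word, idx)
        ids[i] = idx
    if append_eos:
        ids[nwords] = self.eos_index
    return ids
```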
I ended up using a cut sandwich with spm_{encode,decode} bread. It's an awful hack, but it unblocks me for now.
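For completeness, the same sandwich as a Python sketch using the `sentencepiece` package; the model path `spm.model` and `MAX_PIECES` are placeholders:

```python
# spm_encode | cut | spm_decode, rolled into one script. Untested sketch;
# "spm.model" and MAX_PIECES are placeholder values.
import sys

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="spm.model")
MAX_PIECES = 1022  # leave headroom under the model's max_source_positions

for line in sys.stdin:
    pieces = sp.encode(line.rstrip("\n"), out_type=str)  # spm_encode
    print(sp.decode(pieces[:MAX_PIECES]))                # cut + spm_decode
```

The truncated raw text can then be piped into fairseq-interactive as usual.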