
How do I truncate the source when using `fairseq-interactive`?

erip opened this issue 3 years ago · 2 comments

❓ Questions and Help

Before asking:

  1. search the issues.
  2. search the docs.

What is your question?

How do I truncate the source to the maximum source positions length when translating interactively?

Code

What have you tried?

Adding the --truncate-source flag does nothing, and --skip-invalid-size-inputs-valid-test doesn't do what I want either: it skips the entire example instead of truncating and translating it.

What's your environment?

  • fairseq Version (e.g., 1.0 or main): main
  • PyTorch Version (e.g., 1.0): 1.11
  • OS (e.g., Linux): Linux
  • How you installed fairseq (pip, source): source
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

erip · Jun 22 '22 15:06

I searched for "truncate_source" in the fairseq repository and it only appears in def load_dataset of TranslationTask. That method is used by fairseq-train and fairseq-generate, while fairseq-interactive prepares its dataset with dataset = task.build_dataset_for_inference(tokens, lengths, constraints=constraints_tensor). This means you can either edit that method to:

# I have not tested this code. If anything is wrong, please tell me.
# AppendTokenDataset, TruncateDataset, StripTokenDataset and LanguagePairDataset
# are already imported at the top of fairseq/tasks/translation.py.
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
    # NOTE: src_lengths is wrong if you truncate and I did not fix it here,
    # though it should have no impact during inference; lengths are used to compute loss.
    print("length before truncation:", src_tokens[0].size())
    _eos = self.source_dictionary.eos()
    # strip the trailing eos, truncate, then re-append eos (this may be wrong)
    src_dataset = AppendTokenDataset(
        TruncateDataset(
            StripTokenDataset(src_tokens, _eos),
            self.cfg.max_source_positions - 1,
        ),
        _eos,
    )
    dataset = LanguagePairDataset(
        src_dataset,
        src_lengths,
        self.source_dictionary,
        tgt_dict=self.target_dictionary,
        constraints=constraints,
    )
    print("length after truncation:", dataset[0]["source"].size())
    return dataset

or you copy translation.py and define a new task class:

# register under a new name; if you add no new arguments you can reuse TranslationConfig as the dataclass
@register_task("{your new task name}", dataclass=TranslationConfig)
class YourNewTask(TranslationTask):
    # only override or add the methods you need
    ...

and edit that copy so you do not corrupt your local fairseq install. Import the new task class by adding --user-dir, as sketched below.
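
For concreteness, a user-dir version of the approach above might look like the following untested sketch. The directory my_tasks, the file name truncating_translation.py and the task name translation_truncated_inference are made-up placeholders; the truncation logic simply reuses the StripTokenDataset/TruncateDataset/AppendTokenDataset wrapping shown earlier.

# my_tasks/truncating_translation.py
# (my_tasks/__init__.py should import this module so --user-dir can register the task)
from fairseq.data import (
    AppendTokenDataset,
    LanguagePairDataset,
    StripTokenDataset,
    TruncateDataset,
)
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationConfig, TranslationTask


@register_task("translation_truncated_inference", dataclass=TranslationConfig)
class TranslationTruncatedInferenceTask(TranslationTask):
    """Stock translation task, except interactive inputs are truncated to max_source_positions."""

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        eos = self.source_dictionary.eos()
        # strip the trailing eos, truncate, then re-append eos
        src_dataset = AppendTokenDataset(
            TruncateDataset(
                StripTokenDataset(src_tokens, eos),
                self.cfg.max_source_positions - 1,
            ),
            eos,
        )
        return LanguagePairDataset(
            src_dataset,
            src_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )

You would then run fairseq-interactive with --user-dir my_tasks --task translation_truncated_inference plus your usual arguments.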

--

Another option is to copy/edit fairseq/data/dictionary.py. It has a def encode_line, and conveniently there is an index i tracking the token count within a sentence. Cap nwords at your maximum length and break once i gets big enough. The downside is that the fairseq dictionary has no command-line arguments, so you have to write the length down manually.
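
For reference, that edit could look roughly like the untested sketch below. It assumes encode_line in fairseq/data/dictionary.py still has its current upstream shape, slices the word list up front instead of breaking out of the loop (same effect), and hard-codes a made-up MAX_SOURCE_LEN constant because, as noted, the dictionary sees no command-line arguments.

MAX_SOURCE_LEN = 1022  # e.g. max_source_positions minus room for special tokens

def encode_line(
    self,
    line,
    line_tokenizer=tokenize_line,
    add_if_not_exist=True,
    consumer=None,
    append_eos=True,
    reverse_order=False,
):
    words = line_tokenizer(line)
    if reverse_order:
        words = list(reversed(words))
    # truncate here so everything downstream sees at most MAX_SOURCE_LEN tokens
    words = words[:MAX_SOURCE_LEN]
    nwords = len(words)
    ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

    for i, word in enumerate(words):
        idx = self.add_symbol(word) if add_if_not_exist else self.index(word)
        if consumer is not None:
            consumer(word, idx)
        ids[i] = idx
    if append_eos:
        ids[nwords] = self.eos_index
    return ids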

gmryu · Jun 23 '22 04:06

I ended up using a cut sandwich with spm_{encode,decode} bread. This is an awful hack, but it's unblocked for now.
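
Concretely, that pipeline is: spm_encode the raw input, cut the piece sequence down to the model's limit, spm_decode it back to text, and feed the result to fairseq-interactive. A rough Python equivalent using the sentencepiece package (untested; the model path and length limit are placeholders) would be:

import sys

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="spm.model")  # placeholder model path
MAX_PIECES = 1022  # placeholder: max_source_positions minus room for special tokens

def truncate_line(line: str) -> str:
    """Encode to subword pieces, keep the first MAX_PIECES, and decode back to raw text."""
    pieces = sp.encode(line, out_type=str)
    return sp.decode(pieces[:MAX_PIECES])

# read raw text on stdin, write truncated raw text that can be piped into fairseq-interactive
for line in sys.stdin:
    print(truncate_line(line.rstrip("\n")))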

erip · Jun 26 '22 02:06