OpenNMT-py

Introduce validation post-batch and post-epoch handlers

kagrze opened this issue 2 years ago • 3 comments

There are many complaints regarding the lack of BLEU metric logging during training (e.g. #2146, #2129, or #1158). To enable calculation of any validation metric during training, I suggest introducing two validation event handlers:

  • post-batch (where we can generate translations for a given batch, e.g. using onmt.translate.Translator)
  • post-epoch (where we can gather all translated sentences, de-tokenize them, and calculate corpus-wide BLEU, chrF, or any other metric on the whole validation set).

The handlers can be defined by the user and assigned to opt before onmt_train is called:

opt.valid_post_batch_handler = my_custom_valid_post_batch_handler_function
opt.valid_post_epoch_handler = my_custom_valid_post_epoch_handler_function
onmt_train(opt)
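
On the trainer side, the hooks could be invoked from the validation loop roughly as follows. This is only a minimal sketch of the proposed mechanism, not existing OpenNMT-py code; the handler attribute names follow the proposal above:

# Hypothetical sketch of how the validation loop could invoke the proposed
# handlers. Nothing here exists in OpenNMT-py yet; it only illustrates the
# intended control flow.
def validate(model, valid_iter, opt):
    post_batch_results = []
    for batch in valid_iter:
        # ... existing validation loss/statistics computation ...
        if getattr(opt, "valid_post_batch_handler", None) is not None:
            post_batch_results.append(opt.valid_post_batch_handler(model, batch))
    # After the full pass over the validation set, hand all per-batch
    # results to the corpus-level handler.
    if getattr(opt, "valid_post_epoch_handler", None) is not None:
        opt.valid_post_epoch_handler(post_batch_results)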

kagrze avatar Apr 27 '22 14:04 kagrze

Hey @kagrze This is a frequent complaint indeed. Thanks for contributing! Do you plan on adding some example functions (e.g. translation + BLEU computation) as well as a bit of documentation to this PR? Thanks!

francoishernandez avatar Apr 27 '22 16:04 francoishernandez

> Hey @kagrze This is a frequent complaint indeed. Thanks for contributing! Do you plan on adding some example functions (e.g. translation + BLEU computation) as well as a bit of documentation to this PR? Thanks!

I'll try to prepare an example and documentation.

kagrze avatar Apr 27 '22 18:04 kagrze

@francoishernandez this should work:

# Imports are indicative; exact module paths can differ across OpenNMT-py /
# torchtext versions (e.g. torchtext.legacy.data on newer torchtext).
import typing as T
from argparse import Namespace
from copy import copy
from functools import partial
from io import StringIO

import sacrebleu
from torch import Tensor
from torchtext.data import Batch, Field

from onmt.bin.train import _init_train, train as onmt_train
from onmt.models import NMTModel
from onmt.translate import GNMTGlobalScorer, Translator
from onmt.utils.logging import logger

def detokenize_batch(batch: Tensor, field: Field) -> T.List[str]:
    """Convert a (seq_len, batch_size, 1) tensor of token ids back to strings."""
    sentences = []
    for i in range(batch.shape[1]):
        sentence = ""
        for t in range(batch.shape[0]):
            # Map the id back to its token; "▁" is the SentencePiece word boundary.
            token = field.vocab.itos[batch[t, i, 0]].replace("▁", " ")
            if token == field.pad_token or token == field.eos_token:
                break
            if token != field.init_token:
                sentence += token
        sentences.append(sentence.lstrip())
    return sentences

def valid_post_batch_handler(
    model: NMTModel, batch: Batch, model_opt: Namespace, trans_opt: Namespace
) -> T.Tuple[T.List[str], T.List[str]]:
    # Re-use the training initialization only to recover the vocabulary fields.
    _, fields, _ = _init_train(copy(model_opt))
    scorer = GNMTGlobalScorer.from_opt(trans_opt)
    out_stream = StringIO()
    translator = Translator.from_opt(
        model, fields, trans_opt, copy(model_opt), global_scorer=scorer, out_file=out_stream
    )
    # Translate the detokenized source side of the validation batch.
    translator.translate(detokenize_batch(batch.src[0], fields["src"].base_field),
                         batch_size=batch.batch_size)
    out_stream.seek(0)
    # Undo SentencePiece tokenization on the hypotheses: drop the piece
    # separators and map "▁" back to word boundaries.
    hypotheses = [s.replace(" ", "").replace("▁", " ").strip() for s in out_stream]
    return hypotheses, detokenize_batch(batch.tgt, fields["tgt"].base_field)

def valid_post_epoch_handler(post_batch_results: T.List[T.Tuple[T.List[str], T.List[str]]]) -> None:
    # Flatten the per-batch (hypotheses, references) pairs into two
    # corpus-level lists before scoring.
    translations, references = zip(*post_batch_results)
    translations = [sentence for batch in translations for sentence in batch]
    references = [sentence for batch in references for sentence in batch]
    bleu = sacrebleu.corpus_bleu(translations, [references])
    logger.info(f"BLEU score: {bleu.score}")
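
The trans_opts namespace used below is assumed to be built beforehand, e.g. with OpenNMT-py's own translation option parser; a minimal, hypothetical sketch (the flag values are placeholders):

# Hypothetical construction of trans_opts; the flag values are placeholders.
from onmt.utils.parse import ArgumentParser
import onmt.opts

parser = ArgumentParser()
onmt.opts.translate_opts(parser)
trans_opts = parser.parse_args(
    ["-model", "model_step_1000.pt", "-src", "dummy.txt", "-beam_size", "5"]
)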

opt.valid_post_batch_handler = partial(valid_post_batch_handler,
                                       model_opt=copy(opt),
                                       trans_opt=trans_opts)
opt.valid_post_epoch_handler = valid_post_epoch_handler
onmt_train(opt)

kagrze avatar May 08 '22 12:05 kagrze

Closing now that it's merged with #2198.

vince62s avatar Jan 11 '23 10:01 vince62s