OpenNMT-py
Introduce validation post-batch and post-epoch handlers
There are many complaints regarding the lack of BLEU metric logging during training (e.g. #2146, #2129 or #1158). To enable calculation of any validation metric during training, I suggest introducing two validation event handlers:
- post-batch: generate translations for a given batch, e.g. using `onmt.translate.Translator`
- post-epoch: gather all translated sentences, de-tokenize them, and compute corpus-wide BLEU, chrF, or any other metric on the whole validation set.
The handlers can be defined by the user and assigned to `opt` before a call to `onmt_train` is made:

```python
opt.valid_post_batch_handler = my_custom_valid_post_batch_handler_function
opt.valid_post_epoch_handler = my_custom_valid_post_epoch_handler_function
onmt_train(opt)
```
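To make the contract concrete, here is a minimal sketch of how a validation loop could invoke the two handlers. Everything here (`run_validation` and the dummy handlers) is a hypothetical illustration of the proposed protocol, not existing OpenNMT-py API:

```python
from typing import Any, Callable, Iterable, List, Optional

def run_validation(
    batches: Iterable[Any],
    post_batch_handler: Optional[Callable[[Any], Any]] = None,
    post_epoch_handler: Optional[Callable[[List[Any]], None]] = None,
) -> List[Any]:
    """Call the post-batch handler on every batch, collect its results,
    and hand the whole list to the post-epoch handler at the end."""
    results = []
    for batch in batches:
        if post_batch_handler is not None:
            results.append(post_batch_handler(batch))
    if post_epoch_handler is not None:
        post_epoch_handler(results)
    return results

# Dummy handlers: "translate" each batch by upper-casing it,
# then "score" the whole epoch by counting batches.
collected = run_validation(
    ["hello", "world"],
    post_batch_handler=str.upper,
    post_epoch_handler=lambda results: print(f"{len(results)} batches scored"),
)
```

The point is only the data flow: the post-epoch handler receives the accumulated outputs of every post-batch call, which is exactly what a corpus-level metric like BLEU needs.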
Hey @kagrze This is a frequent complaint indeed. Thanks for contributing! Do you plan on adding some example functions (e.g. translation + BLEU computation) as well as a bit of documentation to this PR? Thanks!
I'll try to prepare an example and documentation.
@francoishernandez this should work:

```python
import typing as T
from argparse import Namespace
from copy import copy
from functools import partial
from io import StringIO

import sacrebleu
from torch import Tensor
from torchtext.data import Batch, Field

# Import paths below match OpenNMT-py 2.x; they may differ in other versions.
from onmt.bin.train import _init_train, train as onmt_train
from onmt.models import NMTModel
from onmt.translate import GNMTGlobalScorer, Translator
from onmt.utils.logging import logger


def detokenize_batch(batch: Tensor, field: Field) -> T.List[str]:
    """Turn a (len, batch, feats) tensor of token ids back into sentences,
    undoing SentencePiece's "▁" word-boundary markers."""
    sentences = []
    for i in range(batch.shape[1]):
        tokens = ""
        for t in range(batch.shape[0]):
            token = field.vocab.itos[batch[t, i, 0]].replace("▁", " ")
            if token == field.pad_token or token == field.eos_token:
                break
            if token != field.init_token:
                tokens += token
        sentences.append(tokens.lstrip())
    return sentences


def valid_post_batch_handler(
    model: NMTModel, batch: Batch, model_opt: Namespace, trans_opt: Namespace
) -> T.Tuple[T.List[str], T.List[str]]:
    """Translate one validation batch; return (hypotheses, references)."""
    _, fields, _ = _init_train(copy(model_opt))
    scorer = GNMTGlobalScorer.from_opt(trans_opt)
    out_stream = StringIO()
    translator = Translator.from_opt(
        model, fields, trans_opt, copy(model_opt),
        global_scorer=scorer, out_file=out_stream,
    )
    translator.translate(
        detokenize_batch(batch.src[0], fields["src"].base_field),
        batch_size=batch.batch_size,
    )
    out_stream.seek(0)
    return (
        [s.replace("▁", "") for s in out_stream],
        detokenize_batch(batch.tgt, fields["tgt"].base_field),
    )


def valid_post_epoch_handler(
    post_batch_results: T.List[T.Tuple[T.List[str], T.List[str]]]
) -> None:
    """Flatten the per-batch results and log corpus-level BLEU."""
    translations, references = zip(*post_batch_results)
    translations = [sentence for batch in translations for sentence in batch]
    references = [sentence for batch in references for sentence in batch]
    bleu = sacrebleu.corpus_bleu(translations, [references])
    logger.info(f"BLEU score: {bleu.score}")


opt.valid_post_batch_handler = partial(
    valid_post_batch_handler, model_opt=copy(opt), trans_opt=trans_opts
)
opt.valid_post_epoch_handler = valid_post_epoch_handler
onmt_train(opt)
```
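For reference, the `▁` handling in the code above follows SentencePiece's convention of marking word boundaries with U+2581. The de-tokenization step can be shown in isolation, with no OpenNMT-py dependency:

```python
# SentencePiece marks the start of each word with "▁" (U+2581).
# De-tokenization is: concatenate the pieces, turn "▁" into spaces, trim.
pieces = ["▁The", "▁valid", "ation", "▁set"]
sentence = "".join(pieces).replace("▁", " ").lstrip()
print(sentence)  # The validation set
```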
Closing now that it's merged in #2198.