
Argument in `test_on_dataset` and `train_and_test_on_datasets` functions to write "val_" metrics instead of "test_"

arthur-thuy opened this issue 2 years ago · 2 comments

Is your feature request related to a problem? Please describe. The MetricMixin class only creates "train_" and "test_" metrics in its add_metric method. This works fine when using only a training set and a test set.

However, when also using a validation set such as in the snippets below, this presents a problem.

# Variant 1: train, then evaluate on the validation and test sets separately.
for al_step in range(N_ALSTEP):
    _ = wrapper.train_on_dataset(
        train_dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda
    )
    _ = wrapper.test_on_dataset(val_dataset, BATCH_SIZE)   # recorded as "test_*"
    _ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)  # overwrites "test_*"
    metrics = wrapper.get_metrics()
    # Label the next most uncertain items.
    if not active_loop.step():
        # We're done!
        break

# Variant 2: train and evaluate on the validation set in a single call.
for al_step in range(N_ALSTEP):
    _ = wrapper.train_and_test_on_datasets(
        train_dataset, val_dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda
    )  # validation metrics recorded as "test_*"
    _ = wrapper.test_on_dataset(test_dataset, BATCH_SIZE)  # overwrites "test_*"
    metrics = wrapper.get_metrics()
    # Label the next most uncertain items.
    if not active_loop.step():
        # We're done!
        break

In both cases, the true validation metrics are recorded under the "test_" prefix and are then overwritten when the true test metrics are recorded under that same prefix.
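To make the collision concrete, here is a minimal sketch (plain Python, independent of baal internals) of how a shared "test_" prefix lets the second evaluation overwrite the first:

metrics = {}

def record(prefix, loss):
    # Every evaluation call writes under the same key prefix.
    metrics[f"{prefix}_loss"] = loss

record("test", 0.42)  # evaluation on the validation set
record("test", 0.57)  # evaluation on the test set overwrites the line above
print(metrics)        # {'test_loss': 0.57} -- the validation result is gone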

Describe the solution you'd like It would be nice if the test_on_dataset and train_and_test_on_datasets functions accepted an argument specifying which prefix the metrics are written under ("val_" or "test_").

Describe alternatives you've considered A simple but cumbersome workaround is to build a dict manually and copy the "test_" metrics that actually correspond to the validation set into it under "val_" keys, as follows:

trainval_hist = wrapper.train_and_test_on_datasets(...)
trainval_last = trainval_hist[-1]  # use the metrics logged at the last epoch
metrics[len(active_set)] = {
    "train_loss": trainval_last["train_loss"],
    "train_accuracy": trainval_last["train_accuracy"],
    "dataset_size": len(active_set),
    "epochs_trained": len(trainval_hist),
    # The "test_" entries here actually hold validation metrics; rename them.
    "val_loss": trainval_last["test_loss"],
    "val_accuracy": trainval_last["test_accuracy"],
}
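The same renaming can also be expressed generically. A small hypothetical helper (not part of baal) that copies every "test_" entry of an epoch log under a new prefix:

def relabel_test_metrics(epoch_log, prefix="val"):
    # Copy "test_*" entries under a new prefix so a later test run
    # cannot overwrite them; all other keys are kept unchanged.
    return {
        (k.replace("test_", f"{prefix}_", 1) if k.startswith("test_") else k): v
        for k, v in epoch_log.items()
    }

val_metrics = relabel_test_metrics(trainval_hist[-1])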

Additional context /

arthur-thuy avatar Jun 23 '23 07:06 arthur-thuy

That makes sense! Something like

wrapper.train_and_test_on_datasets(eval_set='val')?

For backward compatibility, we would still keep test as the default.
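For illustration, a hypothetical usage once such an argument exists (the eval_set name and behavior are assumptions, not the current baal API):

wrapper.train_and_test_on_datasets(
    train_dataset, val_dataset, optimizer, BATCH_SIZE,
    use_cuda=use_cuda, eval_set='val',  # would record "val_loss" / "val_accuracy"
)
wrapper.test_on_dataset(test_dataset, BATCH_SIZE)  # default eval_set='test'
metrics = wrapper.get_metrics()  # "val_*" and "test_*" keys now coexist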

What do you think?

Dref360 avatar Jul 17 '23 17:07 Dref360

That would be a good solution in my opinion!

arthur-thuy avatar Jul 17 '23 18:07 arthur-thuy