
AssertionError when using label_cols in to_tf_dataset

Status: Open. lehrig opened this issue 2 years ago (5 comments).

Describe the bug

An incorrect AssertionError is raised when label_cols is used in to_tf_dataset and the label column is named label.

The assertion is in this line: https://github.com/huggingface/datasets/blob/2.4.0/src/datasets/arrow_dataset.py#L475

Steps to reproduce the bug

from datasets import load_dataset
from transformers import DefaultDataCollator

dataset = load_dataset('glue', 'mrpc', split='train')

tf_dataset = dataset.to_tf_dataset(
    columns=["sentence1", "sentence2", "idx"],
    label_cols=["label"],
    batch_size=16,
    collate_fn=DefaultDataCollator(return_tensors="tf"),
)

Expected results

No assertion error.

Actual results

AssertionError: in user code:

    File "/opt/conda/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 475, in split_features_and_labels  *
        assert set(features.keys()).union(labels.keys()) == set(input_batch.keys())
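To see why this check fails, here is a simplified, hypothetical re-creation of it in plain Python. The function name split_features_and_labels_sketch and the toy batch are illustrative assumptions, not the library's actual code. The sketch assumes features and labels are picked out of the collated batch by filtering its keys against columns and label_cols, and then asserts that together they cover every key the collator produced:

```python
def split_features_and_labels_sketch(input_batch, columns, label_cols):
    """Hypothetical simplification of the key check in arrow_dataset.py (datasets 2.4.0)."""
    # Keep only the batch keys the caller asked for as features or as labels.
    features = {k: v for k, v in input_batch.items() if k in columns}
    labels = {k: v for k, v in input_batch.items() if k in label_cols}
    # Every key produced by the collator must be claimed by one of the two groups.
    assert set(features.keys()).union(labels.keys()) == set(input_batch.keys())
    return features, labels


# Toy collated batch: the collator has already renamed "label" to "labels".
collated = {"idx": [0, 1], "labels": [1, 0]}

try:
    split_features_and_labels_sketch(collated, ["idx"], ["label"])
except AssertionError:
    print("AssertionError: 'labels' matches neither columns nor label_cols")

features, labels = split_features_and_labels_sketch(collated, ["idx"], ["labels"])
print(labels)  # {'labels': [1, 0]}
```

The key "labels" is claimed by neither group when label_cols=["label"], so the union of the two key sets is smaller than the batch's key set and the assertion fires.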

Environment info

  • datasets version: 2.4.0
  • Platform: Linux-4.18.0-305.45.1.el8_4.ppc64le-ppc64le-with-glibc2.17
  • Python version: 3.8.13
  • PyArrow version: 7.0.0
  • Pandas version: 1.4.3

lehrig, Jul 29 '22 21:07

cc @Rocketknight1

lhoestq, Jul 31 '22 20:07

Hi @lehrig, this is caused by the data collator renaming "label" to "labels". If you set label_cols=["labels"] in the call, it will work correctly. However, I agree that the cause of the bug is not obvious, so I'll see if I can make a PR that gives a clearer error when the collator renames columns.
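The mismatch can be illustrated without TensorFlow. The sketch below uses two hypothetical helpers: mimic_label_rename only imitates the "label" to "labels" renaming described in this thread (it is not the transformers implementation), and pick_label_cols derives label_cols from the keys that actually survive collation rather than from the dataset's column names:

```python
def mimic_label_rename(example):
    """Hypothetical stand-in for the collator's label handling: "label" becomes "labels"."""
    renamed = dict(example)
    if "label" in renamed:
        renamed["labels"] = renamed.pop("label")
    return renamed


def pick_label_cols(dataset_label_cols, collated_keys):
    """Hypothetical helper: map dataset label names to the names present after collation."""
    picked = []
    for name in dataset_label_cols:
        if name in collated_keys:
            picked.append(name)
        elif name == "label" and "labels" in collated_keys:
            # The collator renamed the column, so request the renamed key instead.
            picked.append("labels")
        else:
            raise KeyError(f"{name!r} not found after collation: {sorted(collated_keys)}")
    return picked


example = {"idx": 0, "label": 1}
collated_keys = set(mimic_label_rename(example))
print(sorted(collated_keys))                      # ['idx', 'labels']
print(pick_label_cols(["label"], collated_keys))  # ['labels']
```

In the original reproduction this amounts to the workaround given in the comment: passing label_cols=["labels"] to to_tf_dataset.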

Rocketknight1, Aug 02 '22 11:08

Thanks, and wow, that seems like a strange side effect of the data collator. Is it really needed?

Why not make it more explicit? For example, DefaultDataCollator could take an optional label_col_name argument naming the label column, and fall back to labels only when it is not provided (documenting that this happens), for backward compatibility.
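A minimal sketch of that proposal, written in plain Python so it runs without transformers. ExplicitLabelCollator and its label_col_name argument are hypothetical (they are not part of the transformers API), and batching is reduced to collecting values into lists instead of building tensors:

```python
from typing import Any, Dict, List, Optional


class ExplicitLabelCollator:
    """Hypothetical collator sketching the proposal: the caller names the label key."""

    def __init__(self, label_col_name: Optional[str] = None):
        # None keeps the current behavior: labels are emitted under "labels".
        self.label_col_name = label_col_name or "labels"

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        batch: Dict[str, List[Any]] = {}
        for example in examples:
            for key, value in example.items():
                # Only the dataset's "label" column is renamed, to the requested key.
                out_key = self.label_col_name if key == "label" else key
                batch.setdefault(out_key, []).append(value)
        return batch


examples = [{"idx": 0, "label": 1}, {"idx": 1, "label": 0}]
print(ExplicitLabelCollator()(examples))         # {'idx': [0, 1], 'labels': [1, 0]}
print(ExplicitLabelCollator("label")(examples))  # {'idx': [0, 1], 'label': [1, 0]}
```

With label_col_name="label", the collated keys match the dataset's column names, so label_cols=["label"] would pass the key check unchanged, while omitting the argument keeps today's "labels" key for existing code.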

lehrig, Aug 02 '22 11:08

Haha, I honestly have no idea why our data collators rename "label" (the standard label column name in our datasets) to "labels" (the standard label column name input to our models). It's been a pain point when I design TF data pipelines, though, because I don't want to hardcode things like that - especially in datasets, because the renaming is something that happens purely at the transformers end. I don't think I could make the change in the data collators themselves at this point, because it would break backward compatibility for everything in PyTorch as well as TF.

In the most recent version of transformers we added a prepare_tf_dataset method to our models which takes care of these details for you, and even chooses appropriate columns and labels for the model you're using. In future we might make that the officially recommended way to convert HF datasets to tf.data.Dataset.

Rocketknight1, Aug 02 '22 11:08

Interesting, that'd be great, especially for clarity. https://huggingface.co/docs/datasets/use_with_tensorflow#data-loading has already improved clarity, but all those options will still confuse people. Looking forward to those improvements, in the hope that there'll be only one way in the future ;)

Anyway, I'm happy with the workaround you provided for the time being. Thank you!

lehrig, Aug 02 '22 11:08