datasets
AssertionError when using label_cols in to_tf_dataset
Describe the bug
An incorrect AssertionError is raised when using label_cols in to_tf_dataset and the label's key name is label.
The assertion is on this line: https://github.com/huggingface/datasets/blob/2.4.0/src/datasets/arrow_dataset.py#L475
Steps to reproduce the bug
from datasets import load_dataset
from transformers import DefaultDataCollator
dataset = load_dataset('glue', 'mrpc', split='train')
tf_dataset = dataset.to_tf_dataset(
    columns=["sentence1", "sentence2", "idx"],
    label_cols=["label"],
    batch_size=16,
    collate_fn=DefaultDataCollator(return_tensors="tf"),
)
Expected results
No assertion error.
Actual results
AssertionError: in user code:
File "/opt/conda/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 475, in split_features_and_labels *
assert set(features.keys()).union(labels.keys()) == set(input_batch.keys())
Environment info
- datasets version: 2.4.0
- Platform: Linux-4.18.0-305.45.1.el8_4.ppc64le-ppc64le-with-glibc2.17
- Python version: 3.8.13
- PyArrow version: 7.0.0
- Pandas version: 1.4.3
cc @Rocketknight1
Hi @lehrig, this is caused by the data collator renaming "label" to "labels". If you set label_cols=["labels"] in the call, it will work correctly. However, I agree that the cause of the bug is not obvious, so I'll see if I can make a PR to clarify things when the collator renames columns.
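The mismatch can be reproduced without TensorFlow. The sketch below is a simplified, stdlib-only stand-in, not the actual datasets code: `rename_label_like_collator` is a hypothetical helper that mimics only the "label" to "labels" rename (none of the rest of DefaultDataCollator), and `split_features_and_labels` re-creates just the set-equality check from the assertion quoted above.

```python
from typing import Any, Dict, List, Tuple

Batch = Dict[str, List[Any]]


def rename_label_like_collator(batch: Batch) -> Batch:
    """Hypothetical stand-in: mimics only the "label" -> "labels" rename."""
    out = dict(batch)  # copy so the caller's batch is left untouched
    if "label" in out:
        out["labels"] = out.pop("label")
    return out


def split_features_and_labels(input_batch: Batch, columns: List[str],
                              label_cols: List[str]) -> Tuple[Batch, Batch]:
    """Simplified re-creation of the key check that raises the AssertionError."""
    features = {k: v for k, v in input_batch.items() if k in columns}
    labels = {k: v for k, v in input_batch.items() if k in label_cols}
    # Every key the collator emitted must be claimed by `columns` or `label_cols`.
    assert set(features.keys()).union(labels.keys()) == set(input_batch.keys())
    return features, labels


raw_batch = {"idx": [0, 1], "label": [1, 0]}
collated = rename_label_like_collator(raw_batch)

# label_cols=["label"] no longer matches the renamed key, so the check fails.
try:
    split_features_and_labels(collated, ["idx"], ["label"])
    raised = False
except AssertionError:
    raised = True
print(raised)  # True

# label_cols=["labels"] matches what the collator emits, so the check passes.
features, labels = split_features_and_labels(collated, ["idx"], ["labels"])
print(labels)  # {'labels': [1, 0]}
```

The collated batch contains "labels", which neither `columns` nor label_cols=["label"] claims, so the union of claimed keys is smaller than the batch's key set and the assertion fires.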
Thanks, and wow, that seems like a strange side effect of the data collator. Is that really needed?
Why not make it more explicit? For example, extend DefaultDataCollator with an optional property label_col_name to be used as the label column, and default to labels only when it is not provided (documenting that this happens) for backward compatibility?
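The proposal can be sketched as a small collator. `ExplicitLabelCollator` and its `label_col_name` parameter are hypothetical, taken from the suggestion above, and are not part of the transformers API; the sketch only stacks per-example dicts into lists and leaves out tensor conversion.

```python
from typing import Any, Dict, List, Optional


class ExplicitLabelCollator:
    """Hypothetical collator with an explicit, documented label column name."""

    def __init__(self, label_col_name: Optional[str] = None):
        # Backward-compatible default: fall back to "labels" when not provided.
        self.label_col_name = label_col_name if label_col_name is not None else "labels"

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        batch: Dict[str, List[Any]] = {}
        for example in examples:
            for key, value in example.items():
                # Only the source column "label" is renamed, and only to the
                # name the caller chose (or the documented default).
                out_key = self.label_col_name if key == "label" else key
                batch.setdefault(out_key, []).append(value)
        return batch


examples = [{"idx": 0, "label": 1}, {"idx": 1, "label": 0}]
print(ExplicitLabelCollator()(examples))                        # {'idx': [0, 1], 'labels': [1, 0]}
print(ExplicitLabelCollator(label_col_name="label")(examples))  # {'idx': [0, 1], 'label': [1, 0]}
```

With the default, existing code that expects "labels" keeps working, while a caller who passes label_col_name="label" could keep label_cols=["label"] in to_tf_dataset.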
Haha, I honestly have no idea why our data collators rename "label" (the standard label column name in our datasets) to "labels" (the standard label column name input to our models). It's been a pain point when I design TF data pipelines, though, because I don't want to hardcode things like that, especially in datasets, because the renaming is something that happens purely at the transformers end. I don't think I could make the change in the data collators themselves at this point, because it would break backward compatibility for everything in PyTorch as well as TF.
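One way a pipeline could avoid hardcoding the rename is to run the collator on a sample batch and compare the keys before and after. The sketch below is a hypothetical heuristic (`resolve_label_cols` is an illustrative name, not a datasets or transformers function): it follows a requested label column only when exactly one key disappeared and exactly one key appeared.

```python
from typing import Iterable, List


def resolve_label_cols(label_cols: List[str], input_keys: Iterable[str],
                       collated_keys: Iterable[str]) -> List[str]:
    """Hypothetical heuristic: follow a label column through a collator rename.

    Compares the keys fed into the collator with the keys it returned on a
    sample batch, instead of hardcoding "label" -> "labels".
    """
    before, after = set(input_keys), set(collated_keys)
    vanished = before - after  # keys the collator dropped or renamed away
    appeared = after - before  # keys the collator introduced
    resolved = []
    for col in label_cols:
        if col in after:
            resolved.append(col)  # the collator kept this name
        elif col in vanished and len(vanished) == 1 and len(appeared) == 1:
            # Exactly one key disappeared and one appeared: treat it as a rename.
            resolved.append(next(iter(appeared)))
        else:
            raise ValueError(f"cannot locate label column {col!r} in {sorted(after)}")
    return resolved


print(resolve_label_cols(["label"], {"idx", "label"}, {"idx", "labels"}))  # ['labels']
```

The heuristic is deliberately conservative: if a collator renames several keys at once, it refuses to guess and raises instead.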
In the most recent version of transformers, we added a prepare_tf_dataset method to our models which takes care of these details for you, and even chooses appropriate columns and labels for the model you're using. In the future, we might make that the officially recommended way to convert HF datasets to tf.data.Dataset.
Interesting, that'd be great, especially for clarity. https://huggingface.co/docs/datasets/use_with_tensorflow#data-loading has already improved clarity, yet all those options will still confuse people. Looking forward to those advances, in the hope there'll be only one way in the future ;)
Anyway, I am happy for the time being with the workaround you provided. Thank you!