autotrain-advanced
[BUG] text classification failing if num_classes in validation < num_classes in training data
Prerequisites
- [X] I have read the documentation.
- [X] I have checked other issues for similar problems.
Backend
Local
Interface Used
CLI
CLI Command
```shell
autotrain --config training.yml
```
UI Screenshots & Parameters
training.yml:

```yaml
task: text_classification
base_model: google-bert/bert-base-multilingual-uncased
project_name: products-to-categories-finetuned
log: tensorboard
backend: local

data:
  path: data/
  train_split: train # this must be either train.csv or train.json
  valid_split: validate # this must be either validate.csv or validate.json
  column_mapping:
    text_column: name # this must be the name of the column containing the text
    target_column: category_id # this must be the name of the column containing the target

params: # Default values...
  max_seq_length: 512
  epochs: 3
  batch_size: 4
  lr: 2e-5
  optimizer: adamw_torch
  scheduler: linear
  gradient_accumulation: 1
  # mixed_precision: fp16

hub:
  username: ${HF_USERNAME}
  token: ${HF_TOKEN}
  push_to_hub: false
```
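To confirm that the mismatch is only about coverage (the validation split is missing some classes rather than containing new ones), the two splits can be compared before training. This is a hypothetical helper, not part of autotrain: `class_counts` and `compare_splits` are illustrative names, and it assumes CSV splits whose target column is `category_id`, as in training.yml. It uses only the standard library.

```python
import csv
from collections import Counter
from collections.abc import Iterable


def class_counts(lines: Iterable[str], target_column: str = "category_id") -> Counter:
    """Count rows per class in one CSV split (header row required)."""
    reader = csv.DictReader(lines)
    return Counter(row[target_column] for row in reader)


def compare_splits(train_lines: Iterable[str], valid_lines: Iterable[str],
                   target_column: str = "category_id"):
    """Return (train class count, valid class count,
    classes absent from valid, classes in valid but never seen in train)."""
    train = class_counts(train_lines, target_column)
    valid = class_counts(valid_lines, target_column)
    missing_in_valid = set(train) - set(valid)
    unseen_in_train = set(valid) - set(train)
    return len(train), len(valid), missing_in_valid, unseen_in_train


# Tiny in-memory example: validation covers only 2 of the 3 training classes.
train_csv = ["name,category_id", "red shoe,1", "blue hat,2", "green bag,3"]
valid_csv = ["name,category_id", "black shoe,1", "white hat,2"]
print(compare_splits(train_csv, valid_csv))  # (3, 2, {'3'}, set())
```

To run it on the real data, pass open file handles instead of lists, e.g. `compare_splits(open("data/train.csv", newline=""), open("data/validate.csv", newline=""))`. An empty last element means every validation class also appears in training, which is the situation this issue describes.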
Error Logs
```text
INFO | 2024-10-01 14:22:42 | __main__:train:70 - loading dataset from disk
ERROR | 2024-10-01 14:22:42 | autotrain.trainers.common:wrapper:120 - train has failed due to an exception: Traceback (most recent call last):
  File "/Users/anthony/.pyenv/versions/3.11.2/lib/python3.11/site-packages/autotrain/trainers/common.py", line 117, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anthony/.pyenv/versions/3.11.2/lib/python3.11/site-packages/autotrain/trainers/text_classification/__main__.py", line 98, in train
    raise ValueError(
ValueError: Number of classes in train and valid are not the same. Training has 1936 and valid has 1064

ERROR | 2024-10-01 14:22:42 | autotrain.trainers.common:wrapper:121 - Number of classes in train and valid are not the same. Training has 1936 and valid has 1064
```
Additional Information
Replacing the check with `if num_classes_valid > num_classes:` (or removing it entirely, since an earlier check already ensures that the validation data contains no classes that are absent from the training data) does not seem to cause any other issues.
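A minimal sketch of the relaxed rule, written as a standalone function rather than a patch to autotrain's actual code (the name `check_classes` and its signature are hypothetical). It rejects the validation split only when it contains labels never seen in training, and otherwise accepts a validation split that covers just a subset of the classes.

```python
from collections.abc import Hashable, Iterable


def check_classes(train_labels: Iterable[Hashable], valid_labels: Iterable[Hashable]) -> int:
    """Hypothetical relaxed train/valid class check.

    Raises ValueError only if the validation labels include a class that
    never appears in training. Returns the number of training classes.
    """
    train_classes = set(train_labels)
    valid_classes = set(valid_labels)

    unseen = valid_classes - train_classes
    if unseen:
        preview = sorted(map(str, unseen))[:10]  # show at most 10 offending labels
        raise ValueError(
            f"Validation data has {len(unseen)} class(es) not present in training data, e.g. {preview}"
        )
    return len(train_classes)


# Validation covers only 2 of the 3 training classes: accepted.
print(check_classes(["shoes", "hats", "bags", "shoes"], ["hats", "bags"]))  # 3

# Validation introduces a class never seen in training: rejected.
try:
    check_classes(["shoes", "hats"], ["hats", "scarves"])
except ValueError as err:
    print(err)
```

A single subset test is enough here: if every validation class also appears in training, the validation class count can never exceed the training class count, so the separate `num_classes_valid > num_classes` comparison becomes redundant.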
Is there a reason for this check? Is it possible to make this change permanent?
Thank you!