
[BUG] text classification failing if num_classes in validation < num_classes in training data

Open Mytrill opened this issue 4 months ago • 2 comments

Prerequisites

  • [X] I have read the documentation.
  • [X] I have checked other issues for similar problems.

Backend

Local

Interface Used

CLI

CLI Command

autotrain --config training.yml

UI Screenshots & Parameters

training.yml:

task: text_classification
base_model: google-bert/bert-base-multilingual-uncased
project_name: products-to-categories-finetuned
log: tensorboard
backend: local

data:
  path: data/ 
  train_split: train # this must be either train.csv or train.json
  valid_split: validate # this must be either validate.csv or validate.json
  column_mapping:
    text_column: name # this must be the name of the column containing the text
    target_column: category_id # this must be the name of the column containing the target

params: # Default values...
  max_seq_length: 512
  epochs: 3
  batch_size: 4
  lr: 2e-5
  optimizer: adamw_torch
  scheduler: linear
  gradient_accumulation: 1
  # mixed_precision: fp16

hub:
  username: ${HF_USERNAME}
  token: ${HF_TOKEN}
  push_to_hub: false

Error Logs

INFO     | 2024-10-01 14:22:42 | __main__:train:70 - loading dataset from disk
ERROR    | 2024-10-01 14:22:42 | autotrain.trainers.common:wrapper:120 - train has failed due to an exception: Traceback (most recent call last):
  File "/Users/anthony/.pyenv/versions/3.11.2/lib/python3.11/site-packages/autotrain/trainers/common.py", line 117, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/anthony/.pyenv/versions/3.11.2/lib/python3.11/site-packages/autotrain/trainers/text_classification/__main__.py", line 98, in train
    raise ValueError(
ValueError: Number of classes in train and valid are not the same. Training has 1936 and valid has 1064

ERROR    | 2024-10-01 14:22:42 | autotrain.trainers.common:wrapper:121 - Number of classes in train and valid are not the same. Training has 1936 and valid has 1064

Additional Information

Replacing the check with if num_classes_valid > num_classes: does not seem to cause any additional issues. Removing it entirely also seems to work, because an earlier check already makes sure that the validation data contains no classes that are missing from the training data.
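For illustration, here is a minimal standalone sketch of the relaxed behaviour described above. It is not the actual autotrain code: the function name check_valid_classes and its arguments are hypothetical, and it assumes labels are plain hashable values such as the category_id entries. It accepts a validation split with fewer classes than training and only fails when validation contains a class that training has never seen.

```python
from typing import Hashable, Iterable


def check_valid_classes(
    train_labels: Iterable[Hashable],
    valid_labels: Iterable[Hashable],
) -> None:
    """Hypothetical helper: fail only on validation classes unseen in training."""
    train_classes = set(train_labels)
    valid_classes = set(valid_labels)

    # Classes that appear in validation but never in training cannot be
    # predicted by the model, so they are the only real error case.
    unseen = valid_classes - train_classes
    if unseen:
        preview = sorted(unseen, key=str)[:10]
        raise ValueError(
            f"Validation data contains {len(unseen)} class(es) "
            f"not present in training data, e.g. {preview}"
        )


# A validation split covering only a subset of the training classes passes.
check_valid_classes(train_labels=[0, 1, 2, 3], valid_labels=[0, 2])
```

With a check of this shape, the case from the logs (1936 classes in training, 1064 in validation) would pass, provided every validation class also occurs in the training data.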

Is there a reason for this check? Is it possible to make this change permanent?

Thank you!

Mytrill, Oct 01 '24 12:10