
Question: How does the number of categories affect training and accuracy?

Open rubensmau opened this issue 2 years ago • 8 comments

I have found that increasing the number of categories reduces the accuracy results. Has anyone studied how an increased number of samples per category affects the results?

rubensmau avatar Feb 02 '23 18:02 rubensmau

Hello!

Many papers regarding few-shot (i.e. not a lot of training samples per class) methods consider K-shot learning in their results, where K is some fixed integer, commonly 8 or 16 and sometimes 4, 32 or 64. The SetFit paper is no exception. One ubiquitous finding is that a higher K always results in higher performance, assuming that K stays within reason, i.e. K is sufficiently low that we may still speak of "few-shot". See the following screenshot from the SetFit paper for an example:

[image: K-shot results table from the SetFit paper]

As you can see, the performance universally increases as K increases from 8 to 64.

Additionally, this can be verified by taking the very first script from the README, treating num_samples as K, and plotting the results. I've done exactly that.

Click to see the script
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from matplotlib import pyplot as plt

from setfit import SetFitModel, SetFitTrainer, sample_dataset


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Set plot interactive mode
plt.ion()

num_samples_metrics = []
for num_samples in range(1, 20):
    # Simulate the few-shot regime by sampling num_samples examples per class
    train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=num_samples)
    eval_dataset = dataset["validation"]

    # Load a SetFit model from Hub
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

    # Create trainer
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=16,
        num_iterations=20, # The number of text pairs to generate for contrastive learning
        num_epochs=1, # The number of epochs to use for contrastive learning
        column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
    )

    # Train and evaluate
    trainer.train()
    metrics = trainer.evaluate()
    
    # Track the number of samples VS the accuracy ratio
    num_samples_metrics.append((num_samples, metrics["accuracy"]))

# Plot the performance
fig, ax = plt.subplots()
ax.plot(*zip(*num_samples_metrics))
fig.suptitle("Effect of K in K-shot learning for binary\ntext classification task with dataset `sst2`")
ax.set_xlabel("K (number of samples from each class)")
ax.set_ylabel("accuracy (ratio)")
plt.show(block=True)
This results in the following graph:

[image: accuracy vs. K plot produced by the script above]

Note that this script samples num_samples elements from the much larger train_dataset to simulate a situation with little training data. This means that, especially at lower K, the quality of the few sampled training examples has a large impact on the performance, which likely explains the large drop from K=6 to K=7. As a result, taking averages over multiple runs for the same K would be required to get a smoother and more useful graph, but I don't have time for that. Even so, this graph should show my point decently well.
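A minimal sketch of what such an averaging loop could look like, reusing the setup from the script above. It assumes that sample_dataset accepts a seed argument (as in recent setfit versions) so that each run draws a different subset; this is a sketch, not a polished benchmark:

import numpy as np
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, sample_dataset

dataset = load_dataset("sst2")

num_samples_metrics = []
for num_samples in range(1, 20):
    accuracies = []
    # Average over several random draws per K to smooth out sampling noise
    for seed in range(5):
        train_dataset = sample_dataset(
            dataset["train"], label_column="label", num_samples=num_samples, seed=seed
        )
        model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
        trainer = SetFitTrainer(
            model=model,
            train_dataset=train_dataset,
            eval_dataset=dataset["validation"],
            loss_class=CosineSimilarityLoss,
            metric="accuracy",
            batch_size=16,
            num_iterations=20,
            num_epochs=1,
            column_mapping={"sentence": "text", "label": "label"},
        )
        trainer.train()
        accuracies.append(trainer.evaluate()["accuracy"])
    # Store the mean and standard deviation of the accuracy for this K
    num_samples_metrics.append((num_samples, np.mean(accuracies), np.std(accuracies)))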

As you can see, moving from 1 or 2 samples to e.g. 5 causes notable improvements, while increasing from e.g. 16 to 18 won't have as much of an influence. I suspect that the shape of this graph will be similar for all problems and datasets, but that the "sweet spot" of labelling time versus performance gain will differ depending on the situation. As a result, if you have e.g. 8 samples per class, it may be interesting to plot the performance when you "pretend" to have fewer samples: the slope of the resulting graph can give you an indication of the performance gain from labelling another 2 samples per class. (That said, perhaps at that point it's better not to spend time writing a plotting script, but just spend the time on the labelling instead, hah!)

Please recognize that this script and the graph are very rudimentary :)

  • Tom Aarsen

tomaarsen avatar Feb 03 '23 08:02 tomaarsen

Thanks for your answer, Tom. But I didn't phrase my question precisely enough. Do you know of similar studies on how the number of classes affects the overall performance, for instance, whether we need to increase the number of samples per class if more classes are added?

rubensmau avatar Feb 03 '23 11:02 rubensmau

I'm currently training a setfit model with 4500 classes, 10 samples per class (using a proprietary dataset).

I think it is still generating pairs though? I just see endless tqdm bars haha

I can share the end accuracy once it gets there :)

logan-markewich avatar Feb 08 '23 17:02 logan-markewich

Do you know of similar studies on how the number of classes affects the overall performance, for instance, whether we need to increase the number of samples per class if more classes are added?

This is a good question, @rubensmau. I suspect that it's also impossible to answer in general, as it depends on how well the classes are separated. SetFit tries to organize the embeddings so that different classes are well separated, so I would expect performance to remain relatively stable as the number of classes grows, as long as the text separates the classes well.

Personally, I don't know of any dataset-independent studies like that. You can cook up examples where you don't need to increase your samples with the number of classes, and others where you do need to do so. 🤔
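As a rough sketch of what "checking how well the text separates the classes" could look like in practice: encode a handful of labelled texts with the model's Sentence Transformer body and compare per-class centroids. This assumes the default SetFitModel layout where the body is exposed as model.model_body; the toy texts and labels below are placeholders, not data from this thread:

import numpy as np
from setfit import SetFitModel

# Toy labelled texts (placeholders); replace with a small sample of your own data
texts = ["great movie", "terrible plot", "wonderful acting", "boring and slow"]
labels = [1, 0, 1, 0]

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Encode with the Sentence Transformer body and L2-normalize the embeddings
embeddings = model.model_body.encode(texts, normalize_embeddings=True)

# Build one normalized centroid per class
centroids = {}
for label in set(labels):
    rows = [i for i, l in enumerate(labels) if l == label]
    centroid = embeddings[rows].mean(axis=0)
    centroids[label] = centroid / np.linalg.norm(centroid)

# Cosine similarity between class centroids: the lower it is, the easier it
# should be to keep the classes apart in the embedding space
for a in sorted(centroids):
    for b in sorted(centroids):
        if a < b:
            print(f"cosine(centroid {a}, centroid {b}) = {float(centroids[a] @ centroids[b]):.3f}")

Running this on the raw (un-finetuned) body versus the fine-tuned body also gives a feel for how much contrastive training actually pushed the classes apart.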

kgourgou avatar Feb 09 '23 19:02 kgourgou

I've noticed that an increase in classes makes it harder for the SetFit model to properly separate the classes in the embedding space. My experiments have shown that datasets with more classes generally need more data before their performance saturates than datasets with fewer classes. For example, I have seen binary classification tasks where increasing from 16 to 32 samples per class gives marginal improvements, while classification tasks with 5 classes still improve fairly significantly when moving from 16 to 32.

In fact, I can run an experiment with exactly that:

python .\scripts\setfit\run_fewshot.py --datasets sst2 sst5 --sample_sizes 2 4 8 16 32 --batch_size 64

This results in:

dataset | measure  | 2_avg       | 4_avg       | 8_avg       | 16_avg      | 32_avg
sst2    | accuracy | 71.5% (9.1) | 77.2% (5.6) | 86.2% (3.3) | 90.5% (0.8) | 91.0% (0.9)
sst5    | accuracy | 32.9% (2.5) | 38.5% (2.6) | 42.6% (2.6) | 46.2% (1.8) | 48.1% (1.3)
  • Each of these experiments was run 10 times; the average accuracies and standard deviations (in parentheses) are shown.
  • Tom Aarsen

tomaarsen avatar Feb 14 '23 10:02 tomaarsen

Hi, sorry to interrupt, but I am facing a similar problem. I am dealing with multi-class classification, and my dataset has 100 categories with a minimum of 30 examples per category.

  • The first problem is how to split the dataset into train and test: making sure I am using k-shot learning while still having enough examples for the test set, and keeping class imbalance in mind. How many examples per category should I have in the test data?
  • The other problem is that I want to use a top-5 accuracy evaluation metric, i.e. a prediction is correct when the top-5 predicted labels contain the target category (a rough sketch of both points follows at the end of this comment).

My dataset is quite large and I will probably sub-sample it: it has 15k examples and 401 categories, and I am planning to experiment with the top 100 categories. The maximum number of examples per category is around 700, while the minimum is 30.
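Neither point is answered later in this thread, but a minimal sketch of one possible approach follows. The data file name and model path are placeholders; it assumes the default scikit-learn LogisticRegression head, whose classes_ attribute maps probability columns back to class ids, and that predict_proba returns one probability per class:

import numpy as np
from datasets import load_dataset
from setfit import SetFitModel, sample_dataset

# Placeholder data file with "text" and "label" columns; replace with your own
dataset = load_dataset("csv", data_files="my_data.csv")["train"]

# K-shot train split: K examples per class; everything else goes to the test pool,
# which keeps the natural (imbalanced) class distribution for evaluation
K = 8
train_dataset = sample_dataset(dataset, label_column="label", num_samples=K)
train_texts = set(train_dataset["text"])
test_dataset = dataset.filter(lambda row: row["text"] not in train_texts)

# ... fine-tune a SetFitModel on train_dataset as usual, then:
model = SetFitModel.from_pretrained("path/to/finetuned-model")  # placeholder path

# Top-5 accuracy: a prediction counts as correct if the true label is among
# the 5 classes with the highest predicted probability
probs = np.asarray(model.predict_proba(test_dataset["text"]))
top5_columns = np.argsort(probs, axis=1)[:, -5:]
classes = np.asarray(getattr(model.model_head, "classes_", np.arange(probs.shape[1])))
hits = [label in classes[cols] for label, cols in zip(test_dataset["label"], top5_columns)]
print("top-5 accuracy:", float(np.mean(hits)))

If evaluating on the full remainder is too slow, capping the test pool at a fixed number of examples per class (again via sampling) keeps the evaluation manageable while staying roughly stratified.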

iHamzaKhanzada avatar Jun 19 '23 09:06 iHamzaKhanzada

I'm currently training a setfit model with 4500 classes, 10 samples per class (using a proprietary dataset).

I think it is still generating pairs though? I just see endless tqdm bars haha

I can share the end accuracy once it gets there :)

Hi @logan-markewich - did you get useful results out of this? It's very similar to what I am trying, but I get ~40%+ accuracy on the training data and ~0% accuracy on the evaluation data.

grofte avatar Jul 26 '23 13:07 grofte

@grofte yea it never worked well for me either. I think the dataset is just too big haha

logan-markewich avatar Jul 26 '23 19:07 logan-markewich