setfit
Question: How does the number of categories affect training and accuracy?
I have found that increasing the number of categories reduces accuracy. Has anyone studied how an increased number of samples per category affects the results?
Hello!
Many papers regarding few-shot methods (i.e. methods trained with only a handful of samples per class) report K-shot learning results, where K is some fixed integer, commonly 8 or 16 and sometimes 4, 32 or 64. The SetFit paper is no exception. One ubiquitous finding is that a higher K always results in higher performance, assuming K stays within reason, i.e. low enough that we can still speak of "few-shot".
See the following screenshot from the SetFit paper for an example:

As you can see, the performance universally increases as K increases from 8 to 64.
Additionally, you can verify this with the very first script from the README by treating num_samples as K and plotting the results. I've done exactly that.
The script:
```python
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from matplotlib import pyplot as plt

from setfit import SetFitModel, SetFitTrainer, sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Set plot interactive mode
plt.ion()

num_samples_metrics = []
for num_samples in range(1, 20):
    # Simulate the few-shot regime by sampling `num_samples` examples per class
    train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=num_samples)
    eval_dataset = dataset["validation"]

    # Load a SetFit model from Hub
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

    # Create trainer
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=16,
        num_iterations=20,  # The number of text pairs to generate for contrastive learning
        num_epochs=1,  # The number of epochs to use for contrastive learning
        column_mapping={"sentence": "text", "label": "label"},  # Map dataset columns to text/label expected by trainer
    )

    # Train and evaluate
    trainer.train()
    metrics = trainer.evaluate()

    # Track the number of samples vs. the accuracy
    num_samples_metrics.append((num_samples, metrics["accuracy"]))

# Plot the performance
fig, ax = plt.subplots()
ax.plot(*zip(*num_samples_metrics))
fig.suptitle("Effect of K in K-shot learning for binary\ntext classification task with dataset `sst2`")
ax.set_xlabel("K (number of samples from each class)")
ax.set_ylabel("accuracy (ratio)")
plt.show(block=True)
```
Note that this script samples num_samples elements from the much larger train_dataset to simulate a situation with little training data. This means that, especially at lower K, the quality of the few training samples that happen to be drawn has a large impact on the performance, which likely explains the large drop from K=6 to K=7. As a result, taking averages over multiple runs for the same K would be required to get a smoother and more useful graph, but I don't have time for that. Even so, this graph should show my point decently well regardless.
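For reference, a minimal sketch of such an averaging loop, assuming the same imports and `dataset` as the script above and that `sample_dataset` accepts a `seed` argument, could look like this:

```python
import numpy as np

num_runs = 5  # number of random seeds to average over per K; an arbitrary choice
num_samples_metrics = []
for num_samples in range(1, 20):
    accuracies = []
    for seed in range(num_runs):
        # Draw a different K-shot training set for every run
        train_dataset = sample_dataset(
            dataset["train"], label_column="label", num_samples=num_samples, seed=seed
        )
        model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
        trainer = SetFitTrainer(
            model=model,
            train_dataset=train_dataset,
            eval_dataset=dataset["validation"],
            loss_class=CosineSimilarityLoss,
            metric="accuracy",
            batch_size=16,
            num_iterations=20,
            num_epochs=1,
            column_mapping={"sentence": "text", "label": "label"},
        )
        trainer.train()
        accuracies.append(trainer.evaluate()["accuracy"])
    # Average out the sampling noise for this K before plotting
    num_samples_metrics.append((num_samples, float(np.mean(accuracies))))
```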
As you can see, moving from 1 or 2 samples to e.g. 5 causes notable improvements, while an increase from e.g. 16 to 18 has much less of an influence. I suspect that the shape of this graph will be the same for all problems and datasets, but that the "sweet spot" of labelling time versus performance gain will differ depending on the situation. As a result, if you have e.g. 8 samples per class, it may be interesting to plot the performance while "pretending" to have fewer samples. The slope of the resulting graph may give you an indication of the performance gain if you labelled another 2 samples per class. (That said, perhaps at that point it's better not to spend time writing a plotting script, but to just spend the time on the labelling instead, hah!)
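A rough sketch of that idea follows. The names `train_and_eval`, `my_train_dataset` and `eval_dataset` are placeholders for your own setup, and "text"/"label" column names are assumed:

```python
# Hypothetical helper: trains a fresh SetFit model on `train_dataset` and returns
# the accuracy on your own evaluation split, mirroring the loop body above.
def train_and_eval(train_dataset):
    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
    )
    trainer.train()
    return trainer.evaluate()["accuracy"]

# `my_train_dataset` stands in for your own labelled data with 8 examples per class:
# "pretend" to have fewer samples by training on subsets of increasing size.
results = [
    (k, train_and_eval(sample_dataset(my_train_dataset, label_column="label", num_samples=k)))
    for k in (2, 4, 6, 8)
]

# The slope of the last segment roughly estimates the gain from labelling more samples per class
(k1, acc1), (k2, acc2) = results[-2], results[-1]
print(f"~{(acc2 - acc1) / (k2 - k1):.3f} accuracy gain per extra labelled sample per class")
```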
Please recognize that this script and the graph are very rudimentary :)
- Tom Aarsen
Thanks for your answer, Tom, but I failed to detail my question well. Do you know of similar studies regarding how the number of classes affects the overall performance? For instance, do we need to increase the number of samples if more classes are added?
I'm currently training a setfit model with 4500 classes, 10 samples per class (using a proprietary dataset).
I think it is still generating pairs though? I just see endless tqdm bars haha
I can share the end accuracy once it gets there :)
> Do you know of similar studies regarding how the number of classes affects the overall performance? For instance, do we need to increase the number of samples if more classes are added?
This is a good question, @rubensmau. I suspect that it's also impossible to answer in general, as it depends on how well the classes are separated. SetFit tries to organize the embeddings so that the different classes are well separated, so I would expect performance to stay relatively stable as the number of classes grows, as long as the texts separate the classes well.
Personally, I don't know of any dataset-independent studies like that. You can cook up examples where you don't need to increase your samples with the number of classes, and others where you do. 🤔
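If you want a feel for how well your own classes separate, one rough check is to compare class centroids in the embedding space. This is just a sketch, assuming the trained SetFitModel exposes its SentenceTransformer body as `model_body` and that `texts` and `labels` are your own evaluation data; the model path is a placeholder:

```python
from itertools import combinations

import numpy as np
from setfit import SetFitModel

model = SetFitModel.from_pretrained("path/to/your-trained-setfit-model")  # placeholder path

# Encode the texts with the fine-tuned SentenceTransformer body
embeddings = model.model_body.encode(texts, normalize_embeddings=True)
labels = np.array(labels)

# One centroid (mean embedding) per class, re-normalized to unit length
classes = np.unique(labels)
centroids = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)

# Average cosine similarity between centroids of different classes:
# the lower this value, the better the classes are separated.
pairwise = [centroids[i] @ centroids[j] for i, j in combinations(range(len(classes)), 2)]
print("mean inter-class centroid similarity:", float(np.mean(pairwise)))
```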
I've noticed that an increase in classes makes it harder for the SetFit model to properly separate the classes in the embedding space. My experiments have shown that datasets with more classes generally improve more slowly as more data is provided than datasets with fewer classes do. For example, I have seen binary classification tasks where increasing from 16 to 32 labelled samples per class gives marginal improvements, while classification tasks with 5 labels improve fairly significantly when moving from 16 to 32.
In fact, I can run an experiment with exactly that:
python .\scripts\setfit\run_fewshot.py --datasets sst2 sst5 --sample_sizes 2 4 8 16 32 --batch_size 64
This results in:
| dataset | measure | 2_avg | 4_avg | 8_avg | 16_avg | 32_avg |
|---|---|---|---|---|---|---|
| sst2 | accuracy | 71.5% (9.1) | 77.2% (5.6) | 86.2% (3.3) | 90.5% (0.8) | 91.0% (0.9) |
| sst5 | accuracy | 32.9% (2.5) | 38.5% (2.6) | 42.6% (2.6) | 46.2% (1.8) | 48.1% (1.3) |
- Each of these experiments was run 10 times; the average accuracies and standard deviations (in parentheses) are shown.
- Tom Aarsen
Hi, sorry to interrupt, but I am facing a similar problem. I am dealing with multi-class classification and my dataset has 100 categories, with a minimum of 30 examples per category.
- The first problem is how to split the dataset into train and test while using K-shot learning: how do I keep enough examples for the test set while accounting for class imbalance? How many examples per category should I have in the test data?
- The other problem is that I want to use a top-5 accuracy evaluation metric, i.e. a prediction is correct when the top-5 predicted labels contain the target category.
My dataset is quite large and I will probably sub-sample it: it has 15k examples and 401 categories, with a maximum of around 700 examples per category and a minimum of 30. I am planning to experiment with the top 100 categories.
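Not an authoritative answer, but here is a rough sketch for both points, assuming a Hugging Face `Dataset` named `dataset` with "text" and "label" columns where "label" is a `ClassLabel` feature, and a trained SetFitModel called `model` (both placeholders). `stratify_by_column` keeps the class imbalance of the test split similar to the full data, and `predict_proba` is assumed to return one probability per class in label-id order:

```python
import numpy as np
from setfit import sample_dataset

# 1) Split: keep a stratified test set, then take K examples per class for training.
K = 8
splits = dataset.train_test_split(test_size=0.2, stratify_by_column="label", seed=42)
train_dataset = sample_dataset(splits["train"], label_column="label", num_samples=K)
test_dataset = splits["test"]

# 2) Top-5 accuracy: a prediction is correct when the true label is among the
#    5 classes with the highest predicted probability.
def top_k_accuracy(model, dataset, k=5):
    probs = np.asarray(model.predict_proba(dataset["text"]))
    top_k = np.argsort(probs, axis=1)[:, -k:]        # indices of the k most likely classes
    targets = np.asarray(dataset["label"]).reshape(-1, 1)
    return float((top_k == targets).any(axis=1).mean())
```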
> I'm currently training a setfit model with 4500 classes, 10 samples per class (using a proprietary dataset).
> I think it is still generating pairs though? I just see endless tqdm bars haha
> I can share the end accuracy once it gets there :)
Hi @logan-markewich - did you get useful results out of this? It's very similar to what I am trying but I get ~40%+ accuracy on the training data and ~0% accuracy on the evaluation data.
@grofte yea it never worked well for me either. I think the dataset is just too big haha