pytorch-metric-learning

Bug in MPerClassSampler: It does not select all classes and samples in each epoch of training

Open • Jalilnkh opened this issue 1 year ago • 1 comment

First of all, I really appreciate this repo. Thank you very much for it! However, one piece of logic does not work the way I would expect: the class and sample selection in MPerClassSampler, in m_per_class_sampler.py. Let's take a look at __iter__(self) in that class:

class MPerClassSampler(Sampler):
.
.
.

    def __iter__(self):
        idx_list = [0] * self.list_size
        i = 0
        skus = []
        num_iters = self.calculate_num_iters()
        for _ in range(num_iters):
            cf_ff.NUMPY_RANDOM.shuffle(self.labels)
            if self.batch_size is None:
                curr_label_set = self.labels
            else:
                curr_label_set = self.labels[: self.batch_size // self.m_per_class]
            skus.extend(curr_label_set)
            for label in curr_label_set:
                t = self.labels_to_indices[label]
                idx_list[i : i + self.m_per_class] = cf_ff.safe_random_choice(
                    t, size=self.m_per_class
                )
                i += self.m_per_class
        return iter(idx_list)

I checked several times, and in no epoch did I get all samples (images in the image dataset) and all classes. The sampler returns a full epoch's worth of image indices, but they are not drawn from all classes, so instead of seeing the available images from every class we get duplicate images.

So, in training, we might lose perhaps half of our data and never be able to use it during the whole training run. I propose fixing this issue.
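
The coverage described above can be checked with a small NumPy simulation. This is a simplified stand-in for the loop, not the library code: `simulate_epoch` is a hypothetical helper, the epoch length (`len(labels) // batch_size` batches) is an assumption, and the draw that samples with replacement only when a class has fewer than `m_per_class` images is an assumed approximation of `safe_random_choice`.

```python
import numpy as np


def simulate_epoch(labels, m_per_class, batch_size, rng):
    """Hypothetical re-creation of the sampling loop: reshuffle the class
    list for every batch, keep the first batch_size // m_per_class classes,
    and draw m_per_class indices from each of them."""
    labels = np.asarray(labels)
    classes = np.unique(labels)
    labels_to_indices = {int(c): np.flatnonzero(labels == c) for c in classes}
    # Assumption: one "epoch" is as many batches as fit into the dataset.
    num_iters = len(labels) // batch_size

    idx_list, seen_classes = [], set()
    for _ in range(num_iters):
        rng.shuffle(classes)  # fresh shuffle for every batch, as in the loop above
        for c in classes[: batch_size // m_per_class]:
            pool = labels_to_indices[int(c)]
            # Assumed stand-in for safe_random_choice:
            # sample with replacement only if the class is too small.
            replace = len(pool) < m_per_class
            drawn = rng.choice(pool, size=m_per_class, replace=replace)
            idx_list.extend(drawn.tolist())
            seen_classes.add(int(c))
    return idx_list, seen_classes


rng = np.random.default_rng(0)
labels = np.repeat(np.arange(100), 10)  # toy data: 100 classes, 10 images each
idx_list, seen_classes = simulate_epoch(labels, m_per_class=4, batch_size=32, rng=rng)

class_coverage = len(seen_classes) / 100
sample_coverage = len(set(idx_list)) / len(labels)
print(f"indices drawn:  {len(idx_list)}")
print(f"classes seen:   {class_coverage:.0%}")
print(f"unique samples: {sample_coverage:.0%}")
```

Because every batch starts from a fresh shuffle, nothing in this loop remembers which classes or images were already used, so repeated draws are expected while other images stay unseen. The exact percentages depend on the seed and on the toy dataset.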

Jalilnkh • Feb 15 '24 12:02

I can see how it is confusing to have "epochs" but not have the entire dataset be used in the epoch. The main point of this sampler is to give a random class-balanced batch at every iteration.

It is guaranteed to cover all labels after m*num_labels / batch_size iterations. It might not cover all labels or samples within the length of the dataset ("epoch"), but that is not always possible anyway.

You're right that the sampler could be improved by guaranteeing coverage of all samples within a certain number of iterations.
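
One possible shape for such an improvement, as a minimal sketch only: keep a shuffled queue of classes and a shuffled queue of indices per class, and refill a queue only once it is empty. `covering_m_per_class_batches` is a hypothetical function, not part of pytorch-metric-learning, and it assumes `batch_size` is a multiple of `m_per_class`.

```python
import numpy as np


def covering_m_per_class_batches(labels, m_per_class, batch_size, num_batches, seed=0):
    """Hypothetical sketch: yield class-balanced batches of dataset indices.

    Classes and per-class indices are popped from shuffled queues that are
    refilled only when empty, so every class and every sample is used once
    before any of them comes around again.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    pools = {int(c): np.flatnonzero(labels == c).tolist() for c in np.unique(labels)}
    classes_per_batch = batch_size // m_per_class
    if batch_size % m_per_class or classes_per_batch > len(pools):
        raise ValueError("batch_size must be a multiple of m_per_class "
                         "and need no more classes than exist")

    class_queue = []
    sample_queues = {c: [] for c in pools}
    for _ in range(num_batches):
        batch, chosen, deferred = [], set(), []
        while len(chosen) < classes_per_batch:
            if not class_queue:  # all classes used: start a new pass
                class_queue = rng.permutation(list(pools)).tolist()
            c = class_queue.pop()
            if c in chosen:  # only possible right after a refill
                deferred.append(c)
                continue
            chosen.add(c)
            for _ in range(m_per_class):
                if not sample_queues[c]:  # all images of c used: reshuffle
                    # Note: for very small classes, an index can repeat
                    # within a batch at this refill boundary.
                    sample_queues[c] = rng.permutation(pools[c]).tolist()
                batch.append(sample_queues[c].pop())
        class_queue.extend(deferred)  # deferred classes go first in the next batch
        yield batch


labels = np.repeat(np.arange(6), 4)  # toy data: 6 classes, 4 images each
batches = list(
    covering_m_per_class_batches(labels, m_per_class=2, batch_size=4, num_batches=6)
)
for b in batches:
    print(b, "classes:", sorted({int(labels[i]) for i in b}))
```

In this toy setup, 6 batches of 4 indices visit each of the 24 images exactly once. The trade-off is that batches are no longer independent random draws, which may or may not matter for a given loss or miner.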

KevinMusgrave • Feb 20 '24 16:02