pytorch-meta

one-vs-all sampler - how to get binary targets?

Open · mmahsereci opened this issue 3 years ago · 2 comments

Hi, great work on the library!

I am attempting to implement a task sampler that samples 1-vs-all tasks. That is, every task is a binary classification problem where class 0 contains shots of one randomly chosen label, and class 1 contains shots drawn from all the remaining classes.
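The task structure described above can be sketched independently of torchmeta. This is a minimal plain-Python illustration, not torchmeta's API: `sample_one_vs_all_task` and the `samples_by_class` dictionary are hypothetical names. It assumes negatives are drawn uniformly from the pooled remaining classes, which is one possible choice (drawing a fixed number per negative class would be another). In an index-based `MetaDataset`, the random choice of the positive class would be replaced by the requested index.

```python
import random
from typing import Dict, Hashable, List, Sequence, Tuple


def sample_one_vs_all_task(
    samples_by_class: Dict[Hashable, Sequence],
    num_shots: int,
    rng: random.Random,
) -> Tuple[Hashable, List[Tuple[object, int]]]:
    """Return (positive_class, [(sample, binary_label), ...]) for one episode."""
    classes = list(samples_by_class)
    # Pick the class that plays the "one" role in this episode.
    positive = rng.choice(classes)
    # Label 0: `num_shots` samples drawn from the chosen class only.
    one_shots = rng.sample(list(samples_by_class[positive]), num_shots)
    # Label 1: `num_shots` samples drawn from the pooled remaining classes,
    # so a single episode may mix several negative classes.
    rest_pool = [x for c in classes if c != positive for x in samples_by_class[c]]
    rest_shots = rng.sample(rest_pool, num_shots)
    task = [(x, 0) for x in one_shots] + [(x, 1) for x in rest_shots]
    return positive, task


# Toy data: hypothetical class names mapped to their samples.
data = {"a": ["a1", "a2", "a3"], "b": ["b1", "b2"], "c": ["c1", "c2"]}
positive, task = sample_one_vs_all_task(data, num_shots=2, rng=random.Random(0))
print([label for _, label in task])  # [0, 0, 1, 1]
```

Whatever class is chosen, the labels come out as binary 0/1 targets rather than the original class identities, which is the property the question is after.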

I inherited from MetaDataset and did the following (my question is below the code):

```python
from torchmeta.transforms import Categorical
from torchmeta.transforms.utils import wrap_transform
from torchmeta.utils.data import ConcatTask, MetaDataset


class OneVsAllMetaDataset(MetaDataset):
    # ... __init__, _copy_categorical_one and _copy_categorical_vs_all omitted ...

    def __getitem__(self, index):
        if not isinstance(index, int):
            raise ValueError('The index of a `OneVsAllMetaDataset` must be an integer')

        # Create 2 datasets for the task: the first one corresponds to label=index,
        # the second one contains all other labels.
        idx_set = [i for i in range(len(self.dataset))]
        del idx_set[index]

        # Use deepcopy on `Categorical` target transforms, to avoid any side
        # effect across tasks.
        dataset_one = ConcatTask([self.dataset[index]],
                                 1,
                                 target_transform=wrap_transform(Categorical(),
                                                                 self._copy_categorical_one,
                                                                 transform_type=Categorical))
        dataset_vs_all = ConcatTask([self.dataset[i] for i in idx_set],
                                    1,
                                    target_transform=wrap_transform(Categorical(),
                                                                    self._copy_categorical_vs_all,
                                                                    transform_type=Categorical))

        task = ConcatTask([dataset_one, dataset_vs_all],
                          self.num_classes_per_task)

        if self.dataset_transform is not None:
            task = self.dataset_transform(task)

        return task
```

After applying the `ClassSplitter` and the `BatchMetaDataLoader`, I get the correct tasks, but the labels are not what I want. For a single task with 5 shots they look like this:

`tensor([[ 0, 0, 0, 0, 0, 211, 727, 613, 198, 435]])`

These are the outputs of the `Categorical` transform. What I want instead is:

`tensor([[ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])`

since each task is meant to be a binary task. Do you have any hints on how to fix this?

I would be happy to open a pull request with the code once it works. 1-vs-all samplers are commonly used these days.

Cheers!

mmahsereci · Mar 15 '21 15:03

You probably want to write your own `target_transform` instead of using the default `Categorical` target transform. You can take inspiration from `Categorical`.
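Following that suggestion, here is a minimal sketch of what such a custom target transform could look like, assuming (as with `Categorical`) that a target transform is a plain callable applied to each raw target. `ConstantTarget` and `OneVsAllTarget` are hypothetical names, not part of torchmeta. In the `__getitem__` from the question, one could try passing `target_transform=ConstantTarget(0)` for `dataset_one` and `target_transform=ConstantTarget(1)` for `dataset_vs_all` in place of the wrapped `Categorical()` (untested against torchmeta).

```python
from typing import Hashable


class ConstantTarget:
    """Hypothetical transform: ignore the raw target and return a fixed label."""

    def __init__(self, label: int):
        self.label = label

    def __call__(self, target: Hashable) -> int:
        # Every sample of the sub-dataset this is attached to gets `self.label`.
        return self.label

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.label})"


class OneVsAllTarget:
    """Hypothetical transform: map `positive` to 0 and any other target to 1."""

    def __init__(self, positive: Hashable):
        self.positive = positive

    def __call__(self, target: Hashable) -> int:
        return 0 if target == self.positive else 1


raw_targets = ["a", "a", "q", "z", "m"]
print([OneVsAllTarget("a")(t) for t in raw_targets])  # [0, 0, 1, 1, 1]
print([ConstantTarget(1)(t) for t in raw_targets])    # [1, 1, 1, 1, 1]
```

Both transforms are stateless. `Categorical`, by contrast, builds its label mapping as it encounters targets, which is presumably why the original code deep-copies it per task "to avoid any side effect across tasks"; these sketches should not need that wrapping. `OneVsAllTarget` only works if `positive` equals the raw target the class datasets actually emit, whereas `ConstantTarget` relies only on which sub-dataset a sample comes from. Defining them as classes rather than lambdas also keeps them picklable, which matters for multi-worker data loading.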

If you manage to implement the one-vs-all sampler, I'd be very happy to have a PR!

tristandeleu · Mar 18 '21 16:03

Thanks for the hint, that makes sense. I'll give it a go!

mmahsereci · Mar 18 '21 17:03