fastai_extensions

Specifying image size for use with datasets besides CIFAR-10

Open rsomani95 opened this issue 5 years ago • 0 comments

@oguiza thanks for sharing this repo, great code!

To make nb_MixMatch.py fully compatible with datasets besides CIFAR-10, one needs to be able to specify the target image size. That isn't possible by default, but it can be done with a minor tweak to the mixmatch(...) function: add a size argument and pass it on to .transform(...), like so:

def mixmatch(learn: Learner, ulist: ItemList, num_workers:int=None, size:Union[int,tuple]=64,
             K: int = 2, T: float = .5, α: float = .75, λ: float = 100) -> Learner:

    labeled_data = learn.data
    if num_workers is None: num_workers = 1
    labeled_data.train_dl.num_workers = num_workers
    bs = labeled_data.train_dl.batch_size
    tfms = [labeled_data.train_ds.tfms, labeled_data.valid_ds.tfms]

    ulist = ulist.split_none()
    ulist.train._label_list = partial(MultiTfmLabelList, K=K)
    train_ul = ulist.label_empty().train           # Train unlabeled LabelList
    valid_ll = learn.data.label_list.valid         # Valid labeled LabelList
    # --------------------------------------------------------------------------
    udata = (LabelLists('.', train_ul, valid_ll)
             .transform(tfms, size=size)
             .databunch(bs=min(bs, len(train_ul)),val_bs=min(bs * 2, len(valid_ll)),
                        num_workers=num_workers,dl_tfms=learn.data.dl_tfms,device=device,
                        collate_fn=MultiCollate)
             .normalize(learn.data.stats))
    # --------------------------------------------------------------------------
    learn.data = udata
    learn.callback_fns.append(partial(MixMatchCallback, labeled_data=labeled_data, T=T, K=K, α=α, λ=λ))
    return learn
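
Since size is typed as Union[int,tuple], it may help to validate it before it reaches .transform(...). Below is a minimal plain-Python sketch, assuming the common convention that a single int means a square target and a tuple means (height, width). The helper name normalize_size is hypothetical. It is not part of nb_MixMatch.py or fastai, and it does not reproduce fastai's own resizing logic.

```python
from typing import Tuple, Union


def normalize_size(size: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """Hypothetical helper: turn an int-or-tuple size into a (height, width) tuple.

    Assumption: an int means a square target and a 2-tuple means (height, width).
    """
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(size, bool):
        raise TypeError("size must be an int or a (height, width) tuple")
    if isinstance(size, int):
        if size <= 0:
            raise ValueError("size must be positive")
        return (size, size)
    if not isinstance(size, (tuple, list)) or len(size) != 2:
        raise ValueError("size must be an int or a (height, width) pair")
    height, width = size
    for side in (height, width):
        if isinstance(side, bool) or not isinstance(side, int):
            raise TypeError("both sides must be ints")
    if height <= 0 or width <= 0:
        raise ValueError("both sides must be positive")
    return (height, width)
```

With something like this in place, a call such as mixmatch(learn, ulist, size=normalize_size(224)) would fail early on a malformed size instead of somewhere inside the data pipeline.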

Just for reference, I'm using a tweaked version of your code here

rsomani95 · Jan 18 '20 17:01