fastai_extensions
Specifying the image size for use with datasets besides CIFAR-10
@oguiza thanks for sharing this repo, great code!
To make nb_MixMatch.py fully compatible with datasets besides CIFAR-10, you need to be able to specify the target image size. That isn't possible by default, but it only takes a minor tweak to the mixmatch(...) function:
# Drop-in replacement for mixmatch() in nb_MixMatch.py. It relies on names that module already provides
# (Learner, ItemList, LabelLists, partial, Union, MultiTfmLabelList, MultiCollate, MixMatchCallback).
def mixmatch(learn: Learner, ulist: ItemList, num_workers: int = None, size: Union[int, tuple] = 64,
             K: int = 2, T: float = .5, α: float = .75, λ: float = 100) -> Learner:
    labeled_data = learn.data
    if num_workers is None: num_workers = 1
    labeled_data.train_dl.num_workers = num_workers
    bs = labeled_data.train_dl.batch_size
    tfms = [labeled_data.train_ds.tfms, labeled_data.valid_ds.tfms]
    ulist = ulist.split_none()
    ulist.train._label_list = partial(MultiTfmLabelList, K=K)
    train_ul = ulist.label_empty().train  # Train unlabeled LabelList
    valid_ll = learn.data.label_list.valid  # Valid labeled LabelList
    # --------------------------------------------------------------------------
    # `size` is the new argument: an int gives square images, a (height, width) tuple non-square ones
    udata = (LabelLists('.', train_ul, valid_ll)
             .transform(tfms, size=size)
             .databunch(bs=min(bs, len(train_ul)), val_bs=min(bs * 2, len(valid_ll)),
                        num_workers=num_workers, dl_tfms=learn.data.dl_tfms,
                        device=labeled_data.device,  # reuse the labeled DataBunch's device (`device` was undefined here)
                        collate_fn=MultiCollate)
             .normalize(learn.data.stats))
    # --------------------------------------------------------------------------
    learn.data = udata
    learn.callback_fns.append(partial(MixMatchCallback, labeled_data=labeled_data, T=T, K=K, α=α, λ=λ))
    return learn
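For illustration, here is how an int or tuple size value is interpreted. This is a minimal standalone sketch, assuming the fastai v1 convention that an int means a square size×size target and a 2-tuple means (height, width). The helper to_hw is hypothetical and is not part of fastai or nb_MixMatch.py.

```python
from typing import Tuple, Union


def _is_int(x) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly
    return isinstance(x, int) and not isinstance(x, bool)


def to_hw(size: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """Hypothetical helper: normalize an image-size spec to (height, width).

    An int is treated as a square target (size x size); a 2-tuple of ints is
    taken as (height, width). Anything else raises TypeError.
    """
    if _is_int(size):
        return (size, size)
    if isinstance(size, tuple) and len(size) == 2 and all(_is_int(s) for s in size):
        return size
    raise TypeError("size must be an int or a (height, width) tuple of ints")


# CIFAR-10 images are 32x32; other datasets may need larger or non-square targets.
print(to_hw(32))         # (32, 32)
print(to_hw(64))         # (64, 64), the default in the tweaked mixmatch()
print(to_hw((128, 96)))  # (128, 96)
```

With the tweak in place, a call such as mixmatch(learn, ulist, size=128) would resize both the unlabeled and validation sets to 128×128, and size=(128, 96) would produce non-square images.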
Just for reference, I'm using a tweaked version of your code here