MAML-Pytorch
Using the Learner object in my project: loss not behaving well
I am writing a blog (I already presented it in this subreddit), and in my last post I did a performance analysis of MAML. I ran several experiments, mainly trying both SGD and Adam at the meta-learning level with different meta-learning rates. In summary: with Adam and LR = $10^{-4}$ the training is too unstable, while with LR = $10^{-5}$ the curve is smoother but barely improves (the loss depends far more on the initialization than on meta-training). Do you have any ideas on how to overcome this? I considered adding Batch Normalization, but in meta-learning each sample is a whole problem (task), so I'm not sure whether Batch Normalization works in this setting.
I'll add images of the latest loss curve (raw, smoothed, and smoothed + zoomed).
My code (also included in the post; not necessary to read for the issue, just for reference):
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler, SubsetRandomSampler, BatchSampler
import torchvision
import matplotlib.pyplot as plt
# Raw Omniglot images as tensors; samples are (image, label) pairs.
omniglot_raw = torchvision.datasets.Omniglot(root="./dataset/omniglot", download=True, transform=torchvision.transforms.ToTensor())
# Work on a copy of the alphabet list: it is shuffled in place further down,
# and shuffling the dataset's private `_alphabets` list would silently mutate
# omniglot_raw's internal state.
alphabets = list(omniglot_raw._alphabets)
# Character identifiers of the form 'alphabet/character' (split on '/' below).
characters = omniglot_raw._characters
num_alphabets = len(alphabets)
num_characters = len(characters)
class MetaSplit:
    """Book-keeping for one meta-level split (meta-train, meta-val or meta-test).

    :param ratio: fraction of all characters the split must reach before
        alphabets start going to the next split
    :param total_num_characters: number of characters in the whole dataset
    """

    def __init__(self, ratio, total_num_characters):
        # Threshold checked while greedily assigning alphabets to splits.
        self.min_num_characters = total_num_characters * ratio
        # Filled in during the split assignment.
        self.alphabets = []
        self.num_characters = 0
        # Computed once all alphabets are assigned.
        self.num_problems = None
metasplits = {'metatrain': MetaSplit(0.7, num_characters),
              'metaval': MetaSplit(0.15, num_characters),
              'metatest': MetaSplit(0.15, num_characters)}
# {alphabet: number of characters}, counted in a single pass over `characters`
# instead of rebuilding and scanning the whole character list once per alphabet.
chars_per_alphabet = {alph: 0 for alph in alphabets}
for char in characters:
    char_alphabet = char.split('/')[0]
    if char_alphabet in chars_per_alphabet:
        chars_per_alphabet[char_alphabet] += 1
# Greedily assign whole alphabets to splits: move on to the next split once the
# current one holds at least its minimum number of characters.
random.shuffle(alphabets)
current_metasplit = 'metatrain'
switch_metasplit_from = {'metatrain': 'metaval', 'metaval': 'metatest'}
for alphabet in alphabets:
    current = metasplits[current_metasplit]
    if current.num_characters >= current.min_num_characters:
        # 'metatest' is the last split: keep filling it instead of raising KeyError.
        current_metasplit = switch_metasplit_from.get(current_metasplit, current_metasplit)
    metasplits[current_metasplit].alphabets.append(alphabet)
    metasplits[current_metasplit].num_characters += chars_per_alphabet[alphabet]
# An alphabet with n characters yields n * (n - 1) / 2 problems (unordered pairs
# of distinct characters); n * (n - 1) is always even, so // is exact.
for metasplit in metasplits:
    metasplits[metasplit].num_problems = sum(
        chars_per_alphabet[alph] * (chars_per_alphabet[alph] - 1) // 2
        for alph in metasplits[metasplit].alphabets)
metabatch_size = 8
num_metabatches = metasplits['metatrain'].num_problems // metabatch_size
class MetaLoader():
    """Iterate over meta-batches of binary "which of two characters?" problems.

    A problem is an unordered pair of distinct characters of the same alphabet,
    so an alphabet with n characters contributes n * (n - 1) / 2 problems.
    Each iteration yields a list of ``metabatch_size`` dicts mapping
    'train' / 'val' / 'test' to a DataLoader over that pair's samples.
    """

    # The index arithmetic assumes every character owns this many consecutive
    # samples in base_dataset (assumed true for torchvision Omniglot; confirm
    # before reusing with another dataset).
    SAMPLES_PER_CHAR = 20

    def __init__(self, base_dataset, metabatch_size, batch_sizes,
                 chars_per_alphabet, problem_ratios):
        """
        :param base_dataset: map-style dataset yielding (image, label) samples
        :param metabatch_size: number of problems per meta-batch
        :param batch_sizes: {'train': int, 'val': int, 'test': int} batch sizes
        :param chars_per_alphabet: {alphabet: number of characters} for this split
        :param problem_ratios: [train, val, test] fractions of each problem's samples
        """
        self.base_dataset = base_dataset
        self.metabatch_size = metabatch_size
        self.batch_sizes = batch_sizes
        self.chars_per_alph = chars_per_alphabet
        # Honour the caller's ratios (they used to be ignored and hard-coded).
        self.problem_ratios = list(problem_ratios)
        self.problems_per_alph = {}
        self.num_problems = 0
        self.__load_quantitative_info__()
        # {alphabet: index of its first character in base_dataset's ordering}
        self.char_offsets = self._compute_char_offsets()
        self.metasampler = BatchSampler(RandomSampler(range(self.num_problems)),
                                        batch_size=self.metabatch_size,
                                        drop_last=True)

    def __load_quantitative_info__(self):
        """Fill ``problems_per_alph`` and ``num_problems`` from the character counts."""
        for alphb, n_chars in self.chars_per_alph.items():
            self.problems_per_alph[alphb] = n_chars * (n_chars - 1) // 2
            self.num_problems += self.problems_per_alph[alphb]

    def _compute_char_offsets(self):
        """Return {alphabet: global index of the alphabet's first character}.

        When the dataset exposes ``_characters`` (as the Omniglot object used in
        this script does), offsets are read from it, so they stay correct even
        though this split only holds some alphabets, in shuffled order.
        Otherwise the alphabets are assumed to be laid out contiguously in
        ``chars_per_alph`` order.

        :raises KeyError: if ``_characters`` exists but lacks one of the alphabets
        """
        characters = getattr(self.base_dataset, '_characters', None)
        if characters is None:
            offsets, running = {}, 0
            for alphb, n_chars in self.chars_per_alph.items():
                offsets[alphb] = running
                running += n_chars
            return offsets
        first_index = {}
        for global_idx, char in enumerate(characters):
            first_index.setdefault(char.replace('\\', '/').split('/')[0], global_idx)
        return {alphb: first_index[alphb] for alphb in self.chars_per_alph}

    def __has_reached__(self, idx, ctr, current):
        """True when ``idx`` falls inside the range [ctr, ctr + current)."""
        return ctr + current > idx

    def __problem_idx_to_samples_idx__(self, problem_idx, alphb,
                                       prbs_on_prev_alphabets,
                                       chars_on_prev_alphabets):
        """Map a global problem index to the sample indices of its two characters.

        :param problem_idx: global problem index
        :param alphb: alphabet the problem belongs to
        :param prbs_on_prev_alphabets: problems in the alphabets iterated before ``alphb``
        :param chars_on_prev_alphabets: base_dataset index of ``alphb``'s first character
        :return: list of 2 * SAMPLES_PER_CHAR sample indices
        :raises IndexError: if the problem does not belong to ``alphb``
        """
        n_chars = self.chars_per_alph[alphb]
        k = problem_idx - prbs_on_prev_alphabets
        if not 0 <= k < n_chars * (n_chars - 1) // 2:
            raise IndexError(f'problem {problem_idx} is not in alphabet {alphb}')
        # Unrank k into the k-th pair (i, j) with i < j, in the order
        # (0, 1), (0, 2), ..., (0, n-1), (1, 2), ... -- every pair exactly once,
        # never the same character twice.
        first = 0
        while k >= n_chars - 1 - first:
            k -= n_chars - 1 - first
            first += 1
        second = first + 1 + k
        spc = self.SAMPLES_PER_CHAR
        return [sample_idx
                for char_idx in (first + chars_on_prev_alphabets,
                                 second + chars_on_prev_alphabets)
                for sample_idx in range(char_idx * spc, (char_idx + 1) * spc)]

    def __build_problem_loader_from_samples__(self, samples_idx):
        """Shuffle ``samples_idx`` (in place), split it by ``problem_ratios`` and wrap each part in a DataLoader."""
        random.shuffle(samples_idx)
        n_samples = len(samples_idx)
        train_val_frontier = int(n_samples * self.problem_ratios[0])
        val_test_frontier = int(train_val_frontier + n_samples * self.problem_ratios[1])
        subsets = {'train': samples_idx[:train_val_frontier],
                   'val': samples_idx[train_val_frontier:val_test_frontier],
                   'test': samples_idx[val_test_frontier:]}
        return {split: DataLoader(dataset=self.base_dataset,
                                  batch_sampler=BatchSampler(SubsetRandomSampler(idx),
                                                             batch_size=self.batch_sizes[split],
                                                             drop_last=True))
                for split, idx in subsets.items()}

    def __get_problem_loader__(self, problem_idx):
        """Return the {'train', 'val', 'test'} loaders of problem ``problem_idx``.

        :raises IndexError: if ``problem_idx`` is not below ``num_problems``
        """
        pbs_ctr = 0
        for alphb in self.chars_per_alph:
            n_problems = self.problems_per_alph[alphb]
            if self.__has_reached__(problem_idx, pbs_ctr, n_problems):
                problem_samples_idx = self.__problem_idx_to_samples_idx__(
                    problem_idx, alphb, pbs_ctr, self.char_offsets[alphb])
                return self.__build_problem_loader_from_samples__(problem_samples_idx)
            pbs_ctr += n_problems
        raise IndexError(f'problem index {problem_idx} out of range '
                         f'({self.num_problems} problems)')

    def __len__(self):
        """Number of meta-batches per pass (incomplete last meta-batch dropped)."""
        return len(self.metasampler)

    def __iter__(self):
        for metabatch in self.metasampler:
            yield [self.__get_problem_loader__(problem_idx) for problem_idx in metabatch]
# Per-split {alphabet: number of characters}; each MetaLoader only sees its split.
chars_per_alphabet = {
    split_name: {
        alph: sum(1 for char in characters if char.split('/')[0] == alph)
        for alph in split.alphabets
    }
    for split_name, split in metasplits.items()
}
metatrain_loader = MetaLoader(
    base_dataset=omniglot_raw,
    metabatch_size=metabatch_size,
    batch_sizes={'train': 8, 'val': 1, 'test': 1},
    chars_per_alphabet=chars_per_alphabet['metatrain'],
    problem_ratios=[0.75, 0.15, 0.1],
)
metaval_loader = MetaLoader(
    base_dataset=omniglot_raw,
    metabatch_size=metabatch_size,
    batch_sizes={'train': 8, 'val': 1, 'test': 1},
    chars_per_alphabet=chars_per_alphabet['metaval'],
    problem_ratios=[0.75, 0.15, 0.1],
)
# Meta-test draws one problem at a time.
metatest_loader = MetaLoader(
    base_dataset=omniglot_raw,
    metabatch_size=1,
    batch_sizes={'train': 8, 'val': 1, 'test': 1},
    chars_per_alphabet=chars_per_alphabet['metatest'],
    problem_ratios=[0.75, 0.15, 0.1],
)
n_epochs = 15
class SimpleNet(nn.Module):
    """Small CNN mapping a batch of 1-channel images to one probability per image.

    Four (5x5 conv -> ReLU -> 2x2 max-pool) stages feed a 2-layer MLP. The
    flattened size 16 * 2 * 2 is what e.g. 1x105x105 inputs produce
    (side length 105 -> 50 -> 23 -> 9 -> 2 after each stage).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 10, 5)
        self.conv3 = nn.Conv2d(10, 12, 5)
        self.conv4 = nn.Conv2d(12, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 2 * 2, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        """Return probabilities of shape ``(batch,)`` -- also when batch == 1."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        x = torch.sigmoid(self.fc2(x))
        # Drop only the unit output dim: a bare squeeze() would also drop the
        # batch dim for a batch of one, and BCE would then reject the shape.
        return x.squeeze(1)
def process_labels(labels_raw, ref_label):
    """Binarise class labels: 1.0 where a label equals ``ref_label``, else 0.0."""
    is_reference = labels_raw.eq(ref_label)
    return is_reference.float()
def preprocess_inputs(inputs):
    """Invert intensities (x -> 1 - x) and rescale them by 255."""
    return 255 * (1 - inputs)
def make_step(model, outputs, labels, update_lr, in_weights, loss_fn=None, first_order=True):
    """Take one inner-loop SGD step on the fast weights.

    :param model: kept for interface compatibility; the gradient is taken
        w.r.t. ``in_weights``, the tensors that actually produced ``outputs``
    :param outputs: predicted probabilities computed from ``in_weights``
    :param labels: binary float targets, same shape as ``outputs``
    :param update_lr: inner-loop learning rate
    :param in_weights: current fast weights (iterable of tensors)
    :param loss_fn: loss to minimise; defaults to the module-level ``criterion``
    :param first_order: if True (default, matching the previous behaviour) the
        gradient itself is not differentiated (first-order MAML); pass False
        to keep the second-order terms in the meta-gradient
    :return: (list of updated fast weights, loss tensor, accuracy tensor)
    """
    if loss_fn is None:
        loss_fn = criterion
    in_weights = list(in_weights)
    loss = loss_fn(outputs, labels)
    # Differentiate w.r.t. the weights used in the forward pass. Differentiating
    # w.r.t. model.parameters() only coincides with this while every previous
    # step was first-order.
    grads = torch.autograd.grad(loss, in_weights, create_graph=not first_order)
    out_weights = [w - update_lr * g for w, g in zip(in_weights, grads)]
    # A prediction counts as positive when its probability exceeds 0.5.
    accuracy = (((1 - outputs) < outputs).float() == labels).float().mean()
    return out_weights, loss, accuracy
def update_model(model, new_weights, param_keys):
    """Write ``new_weights`` straight into ``model``'s submodule parameter tables.

    :param param_keys: sequence of (submodule name, parameter name) keys,
        aligned element-wise with ``new_weights``
    """
    for key, tensor in zip(param_keys, new_weights):
        module_name, attr_name = key[0], key[1]
        target = model._modules[module_name]
        target._parameters[attr_name] = tensor
# Draw one meta-batch and expose the first problem's loaders for quick inspection.
toy_metabatch = next(iter(metatrain_loader))
toy_problem_loader, toy_problem_loader_val, toy_problem_loader_test = (
    toy_metabatch[0]['train'],
    toy_metabatch[0]['val'],
    toy_metabatch[0]['test'],
)
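# NOTE: The notebook-to-script export commented out lines starting with '%' (treating them as IPython magics); this also commented out the '%' format arguments in Learner.extra_repr.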
class Learner(nn.Module):
    """Config-driven network whose forward pass can run on external ("fast") weights.

    ``config`` is a list of ``(name, param)`` pairs. Every trainable tensor is
    stored, in config order, in the flat ``self.vars`` ParameterList (weight
    then bias for each conv2d / convt2d / linear / bn layer), so ``forward``
    can be called with a replacement list of tensors, as an inner-loop
    (MAML-style) update needs.
    """

    def __init__(self, config, imgc, imgsz):
        """
        :param config: list of (layer name, list of layer params)
        :param imgc: number of input channels; kept for interface compatibility, not used here
        :param imgsz: input image size; kept for interface compatibility, not used here
        :raises NotImplementedError: for an unknown layer name
        """
        super(Learner, self).__init__()
        self.config = config
        # All tensors optimised by the meta-optimizer, in config order.
        self.vars = nn.ParameterList()
        # BatchNorm running_mean / running_var (never trained).
        self.vars_bn = nn.ParameterList()
        # Compare names with ==, not `is`: identity of equal strings is a
        # CPython interning detail, so `name is 'conv2d'` fails for configs
        # whose strings are built at runtime (and warns on Python 3.8+).
        for name, param in self.config:
            if name == 'conv2d':
                # param: [ch_out, ch_in, kernelsz, kernelsz, stride, padding]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # bias: [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == 'convt2d':
                # param: [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # bias: [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[1])))
            elif name == 'linear':
                # param: [ch_out, ch_in]
                w = nn.Parameter(torch.ones(*param))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # bias: [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == 'bn':
                # param: [ch_out]; learnable scale (ones) and shift (zeros)
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
                # Running statistics must not receive gradients.
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])
            elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
                          'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
                continue
            else:
                raise NotImplementedError

    def extra_repr(self):
        """One line per configured layer, shown inside the module's repr."""
        info = ''
        for name, param in self.config:
            if name == 'conv2d':
                # The % arguments were lost to a notebook export; restored here.
                tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' \
                      % (param[1], param[0], param[2], param[3], param[4], param[5])
                info += tmp + '\n'
            elif name == 'convt2d':
                tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' \
                      % (param[0], param[1], param[2], param[3], param[4], param[5])
                info += tmp + '\n'
            elif name == 'linear':
                tmp = 'linear:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'
            elif name == 'leakyrelu':
                tmp = 'leakyrelu:(slope:%f)' % (param[0])
                info += tmp + '\n'
            elif name == 'avg_pool2d':
                tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)' % (param[0], param[1], param[2])
                info += tmp + '\n'
            elif name == 'max_pool2d':
                tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)' % (param[0], param[1], param[2])
                info += tmp + '\n'
            elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']:
                tmp = name + ':' + str(tuple(param))
                info += tmp + '\n'
            else:
                raise NotImplementedError
        return info

    def forward(self, x, vars=None, bn_training=True):
        """Run the configured layers on ``x``.

        Passing ``vars`` runs the network on those tensors instead of
        ``self.vars``, so an inner-loop update never overwrites the
        meta-parameters. BatchNorm running statistics always come from
        ``self.vars_bn``.

        :param x: input batch (for conv configs: [b, c, h, w])
        :param vars: optional list of tensors ordered exactly like ``self.vars``
        :param bn_training: forwarded as ``training`` to F.batch_norm; when True,
            batch statistics are used and the running buffers are updated in place
        :return: the network output tensor
        """
        if vars is None:
            vars = self.vars
        idx = 0
        bn_idx = 0
        for name, param in self.config:
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                idx += 2
                bn_idx += 2
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # Keeps the batch dim; e.g. param [2, 2, 2] turns [b, 8] into [b, 2, 2, 2],
                # and an empty param turns [b, 1] into [b].
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                # F.interpolate is the non-deprecated replacement of F.upsample_nearest.
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError
        # Every supplied tensor must have been consumed exactly once.
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)
        return x

    def zero_grad(self, vars=None):
        """Zero ``.grad`` of each tensor in ``vars`` (default ``self.vars``) that has one."""
        with torch.no_grad():
            for p in (self.vars if vars is None else vars):
                if p.grad is not None:
                    p.grad.zero_()

    def parameters(self, recurse=True):
        """Return the ParameterList ``self.vars`` itself rather than a generator.

        ``recurse`` is accepted for compatibility with ``nn.Module.parameters``
        and ignored.
        """
        return self.vars
# Layer spec consumed by Learner: four (5x5 conv -> ReLU -> 2x2 max-pool) stages
# with channels 1 -> 6 -> 10 -> 12 -> 16, then a 64 -> 10 -> 1 MLP and a sigmoid.
# 'reshape' with no extra dims views the (batch, 1) output as (batch,).
net_config = [
    layer
    for ch_out, ch_in in ((6, 1), (10, 6), (12, 10), (16, 12))
    for layer in (
        ('conv2d', [ch_out, ch_in, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
    )
] + [
    ('flatten', []),
    ('linear', [10, 64]),
    ('relu', [True]),
    ('linear', [1, 10]),
    ('sigmoid', []),
    ('reshape', []),
]
printlines = []
model = Learner(net_config, imgc=1, imgsz=28)
# The network's last layer is already a sigmoid, so `outputs` are probabilities.
# BCEWithLogitsLoss would apply a second sigmoid, confining predictions to
# (0.5, 0.73): the loss on negative examples can never drop below log(2) and
# gradients are strongly damped, which flattens the meta-learning curve.
# BCELoss is the criterion that matches probability outputs (make_step uses
# this global too). Equivalent alternative: drop the final sigmoid layer and
# keep BCEWithLogitsLoss, thresholding logits at 0 for accuracy.
criterion = nn.BCELoss()
update_lr = 0.01
meta_lr = 0.00001
n_epochs = 15
n_metaepochs = 2
metaoptimizer = optim.Adam(model.parameters(), lr=meta_lr)


def _log(message):
    """Append ``message`` to ``printlines`` and echo it to stdout."""
    printlines.append(message)
    print(message)


for metaepoch in range(n_metaepochs):
    _log('===============================')
    _log(f'// Meta-Epoch {metaepoch + 1} //')
    _log('===============================')
    for mi, metabatch in enumerate(metatrain_loader, 0):  # Meta-step
        print(mi)
        _log(f'{mi} updates at Meta-Level')
        running_loss = 0.0  # Sum of the per-problem meta-losses of this meta-batch
        for pi, problem_loaders in enumerate(metabatch, 0):  # Problem in the meta-batch
            _log(f'- Problem {pi + 1} -')
            problem_loader = problem_loaders['train']
            problem_loader_val = problem_loaders['val']
            ref_label = None
            # The inner loop starts from the meta-parameters; make_step returns new
            # (fast) tensors, so the meta-parameters themselves are never overwritten.
            new_weights = model.parameters()
            problem_meta_loss = None
            for epoch in range(n_epochs):  # Epoch in the problem training
                _log(f'Epoch {epoch + 1}')
                for i, data in enumerate(problem_loader, 0):  # Step in the problem
                    inputs_raw, labels_raw = data
                    inputs = 1 - inputs_raw
                    outputs = model(inputs, new_weights)
                    if ref_label is None:
                        ref_label = labels_raw[0]  # First step: this character becomes the positive class
                    labels = process_labels(labels_raw, ref_label)
                    new_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, new_weights)
                    _log(f'Epoch {epoch + 1}, step {i + 1:5d}], Loss: {loss.item()}, Accuracy: {accuracy}')
                # Validate after every epoch. Only the last epoch's pass feeds the
                # meta-gradient, so earlier passes run without building a graph.
                is_last_epoch = epoch == n_epochs - 1
                val_losses = []
                val_accuracy = 0.0
                with torch.set_grad_enabled(is_last_epoch):
                    for datav in problem_loader_val:
                        inputs_rawv, labels_rawv = datav
                        inputsv = 1 - inputs_rawv
                        outputsv = model(inputsv, new_weights)
                        labelsv = process_labels(labels_rawv, ref_label)
                        val_losses.append(criterion(outputsv, labelsv))
                        val_accuracy += (((1 - outputsv) < outputsv).float() == labelsv).sum()
                n_val = max(len(val_losses), 1)
                val_loss = sum(lv.item() for lv in val_losses)
                _log(f'Epoch {epoch + 1}, VALIDATION], Loss: {val_loss / n_val}, Accuracy: {val_accuracy / n_val}')
                if is_last_epoch and val_losses:
                    # Meta-objective of this problem: mean loss over the WHOLE
                    # validation split. Using only the last batch (a single image
                    # with val batch size 1) makes the meta-gradient very noisy.
                    problem_meta_loss = torch.stack(val_losses).mean()
            if problem_meta_loss is not None:
                running_loss = running_loss + problem_meta_loss
        if torch.is_tensor(running_loss):  # Skip the update if no problem had validation data
            metastep_loss = running_loss / metabatch_size  # Mean over the problems of the meta-batch
            metaoptimizer.zero_grad()
            metastep_loss.backward()
            metaoptimizer.step()
        if (mi + 1) % 1000 == 0:  # Meta-validation every 1000 meta-steps
            _log('META-VALIDATION STEP:')
            for mbvi, metabatch_val in enumerate(metaval_loader):  # Meta-validation meta-step
                report = (mbvi + 1) % 10 == 0
                if report:
                    _log(f'Validation step {mbvi + 1}')
                for problem_loaders in metabatch_val:  # Problem in the meta-validation meta-batch
                    problem_loader = problem_loaders['train']
                    problem_loader_val = problem_loaders['val']
                    ref_label = None
                    new_weights = model.parameters()
                    val_loss = 0.0
                    val_accuracy = 0.0
                    n_val = 0
                    for epoch in range(n_epochs):  # Epoch in the problem training
                        val_loss = 0.0
                        val_accuracy = 0.0
                        n_val = 0
                        for i, data in enumerate(problem_loader, 0):  # Step in the problem
                            inputs_raw, labels_raw = data
                            inputs = 1 - inputs_raw
                            # Predict with the adapted weights; calling model(inputs)
                            # would use the un-adapted meta-parameters, so the
                            # adaptation would never be evaluated.
                            outputs = model(inputs, new_weights)
                            if ref_label is None:
                                ref_label = labels_raw[0]
                            labels = process_labels(labels_raw, ref_label)
                            new_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, new_weights)
                        with torch.no_grad():  # Reporting only: no meta-gradient needed
                            for datav in problem_loader_val:
                                inputs_rawv, labels_rawv = datav
                                inputsv = 1 - inputs_rawv
                                outputsv = model(inputsv, new_weights)
                                labelsv = process_labels(labels_rawv, ref_label)
                                val_loss += criterion(outputsv, labelsv).item()
                                val_accuracy += (((1 - outputsv) < outputsv).float() == labelsv).sum()
                                n_val += 1
                    if report:  # Report the last epoch's validation of each problem
                        n_val = max(n_val, 1)
                        _log(f'Last epoch, VALIDATION], Loss: {val_loss / n_val}, Accuracy: {val_accuracy / n_val}')
            _log('END OF META-VALIDATION STEP')