MAML-Pytorch icon indicating copy to clipboard operation
MAML-Pytorch copied to clipboard

Using the Learner object for my project, Loss not behaving at its best

Open Metabloggism opened this issue 2 years ago • 0 comments

I am writing a blog (I already presented it in this subreddit) and in my last post, I did a performance analysis of MAML. I ran several experiments, basically trying at the Meta-Learning level both SGD and Adam and different (Meta-)LR's, but summarizing when I try Adam with LR=10^-4 the training is too unstable. At the same time, if I modify it to LR=10^-5 the curve is better but it doesn't improve much (basically the Loss function depends way more on the initialization). Do you have ideas on how to overcome this issue? I think I could apply some Batch Normalization but in Meta-Learning samples are problems, and I'm not sure about if Batch Normalization will work in Meta-Learning.

I'll add images from the last Loss function (raw, smoothed and smoothed+zoomed).

Raw

Smoothed

Smoothed+scaled

My code (also in the post and not necessary to read for the issue, just for support):

import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler, SubsetRandomSampler, BatchSampler
import torchvision
import matplotlib.pyplot as plt

omniglot_raw = torchvision.datasets.Omniglot(root="./dataset/omniglot", download=True, transform=torchvision.transforms.ToTensor())


alphabets = omniglot_raw._alphabets
characters = omniglot_raw._characters


num_alphabets = len(alphabets)
num_characters = len(characters)

class MetaSplit:
  def __init__(self, ratio, total_num_characters):
    self.alphabets = []
    self.num_characters = 0
    self.min_num_characters = total_num_characters * ratio
    self.num_problems = None

metasplits = {'metatrain': MetaSplit(0.7, num_characters),
              'metaval': MetaSplit(0.15, num_characters),
              'metatest': MetaSplit(0.15, num_characters)}

chars_per_alphabet = {alph: [char.split('/')[0] for char in characters].count(alph) for alph in alphabets}

random.shuffle(alphabets)

current_metasplit = 'metatrain'
switch_metasplit_from = {'metatrain': 'metaval', 'metaval': 'metatest'}

for alphabet in alphabets:
  if not metasplits[current_metasplit].num_characters < metasplits[current_metasplit].min_num_characters:
    current_metasplit = switch_metasplit_from[current_metasplit]
  metasplits[current_metasplit].alphabets.append(alphabet)
  metasplits[current_metasplit].num_characters += chars_per_alphabet[alphabet]

for metasplit in metasplits:
  metasplits[metasplit].num_problems = 1/2 * sum([chars_per_alphabet[alph]**2 - chars_per_alphabet[alph] for alph in metasplits[metasplit].alphabets])

metabatch_size = 8
num_metabatches = int(metasplits['metatrain'].num_problems / metabatch_size)

class MetaLoader():
    """
    """
    def __init__(self, base_dataset, metabatch_size, batch_sizes, 
                 chars_per_alphabet, problem_ratios):
        self.base_dataset = base_dataset
        self.metabatch_size = metabatch_size
        self.batch_sizes = batch_sizes
        self.chars_per_alph = chars_per_alphabet
        self.problem_ratios = [0.75, 0.15, 0.1]
        self.problems_per_alph = {}
        self.num_problems = 0
        self.__load_quantitative_info__()
        self.metasampler = BatchSampler(RandomSampler(range(self.num_problems)), 
                                        batch_size=self.metabatch_size, 
                                        drop_last=True)
    
    def __load_quantitative_info__(self):
        for alphb in self.chars_per_alph:
            self.problems_per_alph[alphb] = int((self.chars_per_alph[alphb]**2 - 
                                                self.chars_per_alph[alphb]) / 2)
            self.num_problems += self.problems_per_alph[alphb]
    
    def __has_reached__(self, idx, ctr, current):
        return ctr + current > idx
    
    def __problem_idx_to_samples_idx__(self, problem_idx, alphb, 
                                       prbs_on_prev_alphabets, 
                                       chars_on_prev_alphabets):
        pb_idx_in_alph = problem_idx - prbs_on_prev_alphabets
        ichars_in_alphabet = (int(pb_idx_in_alph / self.chars_per_alph[alphb]), 
                                pb_idx_in_alph % self.chars_per_alph[alphb])
        ichars = tuple([ich + chars_on_prev_alphabets \
                        for ich in ichars_in_alphabet])
        return [sample_idx for charidx in ichars 
                for sample_idx in range(charidx * 20, (charidx + 1) * 20)]
    
    def __build_problem_loader_from_samples__(self, samples_idx):

        random.shuffle(samples_idx)

        train_val_frontier = int(len(samples_idx) * self.problem_ratios[0])
        val_test_frontier = int(train_val_frontier + 
                                len(samples_idx) * self.problem_ratios[1])
        
        samples_idx_train = samples_idx[:train_val_frontier]
        samples_idx_val = samples_idx[train_val_frontier:val_test_frontier]
        samples_idx_test = samples_idx[val_test_frontier:]

        train_sampler = BatchSampler(SubsetRandomSampler(samples_idx_train), 
                                     batch_size=self.batch_sizes['train'], 
                                     drop_last=True)
        val_sampler = BatchSampler(SubsetRandomSampler(samples_idx_val), 
                                   batch_size=self.batch_sizes['val'], 
                                   drop_last=True)
        test_sampler = BatchSampler(SubsetRandomSampler(samples_idx_test), 
                                    batch_size=self.batch_sizes['test'], 
                                    drop_last=True)
        loaders = {'train': DataLoader(dataset=self.base_dataset, 
                                       batch_sampler=train_sampler),
                   'val': DataLoader(dataset=self.base_dataset, 
                                       batch_sampler=val_sampler),
                   'test': DataLoader(dataset=self.base_dataset, 
                                       batch_sampler=test_sampler)}
        return loaders

        
    def __get_problem_loader__(self, problem_idx):
        pbs_ctr = 0
        chars_ctr = 0
        for alphb in self.chars_per_alph:
            if not self.__has_reached__(problem_idx, pbs_ctr, 
                                        self.problems_per_alph[alphb]):
                pbs_ctr += self.problems_per_alph[alphb]
                chars_ctr += self.chars_per_alph[alphb]
            else:
                problem_samples_idx = self.__problem_idx_to_samples_idx__(
                    problem_idx, alphb, pbs_ctr, chars_ctr)
                return self.__build_problem_loader_from_samples__(
                    problem_samples_idx)

    def  __iter__(self):
        for imetabatch, metabatch in enumerate(self.metasampler):
            problem_loaders = []
            for problem_idx in metabatch:
                problem_loaders.append(self.__get_problem_loader__(problem_idx))
            yield problem_loaders

chars_per_alphabet = {split: {alph: [char.split('/')[0] for char in characters].count(alph) for alph in metasplits[split].alphabets} for split in metasplits}

metatrain_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=metabatch_size, batch_sizes={'train': 8, 'val': 1, 'test': 1}, chars_per_alphabet=chars_per_alphabet['metatrain'], problem_ratios = [0.75, 0.15, 0.1])
metaval_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=metabatch_size, batch_sizes={'train': 8, 'val': 1, 'test': 1}, chars_per_alphabet=chars_per_alphabet['metaval'], problem_ratios = [0.75, 0.15, 0.1])
metatest_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=1, batch_sizes={'train': 8, 'val': 1, 'test': 1}, chars_per_alphabet=chars_per_alphabet['metatest'], problem_ratios = [0.75, 0.15, 0.1])

n_epochs = 15

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 10, 5)
        self.conv3 = nn.Conv2d(10, 12, 5)
        self.conv4 = nn.Conv2d(12, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 2 * 2, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.sigmoid(self.fc2(x))
        x = x.squeeze()
        return x


def process_labels(labels_raw, ref_label):
  return (labels_raw == ref_label).float()

def preprocess_inputs(inputs):
    return (1- inputs) * 255

def make_step(model, outputs, labels, update_lr, in_weights):
    loss = criterion(outputs, labels)
    grads = torch.autograd.grad(loss, model.parameters())
    out_weights = list(map(lambda p: p[1] - update_lr * p[0], zip(grads, in_weights)))
    accuracy = (((1 - outputs) < outputs).float() == labels).sum() / outputs.shape[0]
    return out_weights, loss, accuracy

def update_model(model, new_weights, param_keys):
    for param, param_key in zip(new_weights, param_keys):
        model._modules[param_key[0]]._parameters[param_key[1]] = param

toy_metabatch = next(iter(metatrain_loader))
toy_problem_loader = toy_metabatch[0]['train']
toy_problem_loader_val = toy_metabatch[0]['val']
toy_problem_loader_test = toy_metabatch[0]['test']

# Commented out IPython magic to ensure Python compatibility.
class Learner(nn.Module):
    """

    """

    def __init__(self, config, imgc, imgsz):
        """

        :param config: network config file, type:list of (string, list)
        :param imgc: 1 or 3
        :param imgsz:  28 or 84
        """
        super(Learner, self).__init__()


        self.config = config

        # this dict contains all tensors needed to be optimized
        self.vars = nn.ParameterList()
        # running_mean and running_var
        self.vars_bn = nn.ParameterList()

        for i, (name, param) in enumerate(self.config):
            if name is 'conv2d':
                # [ch_out, ch_in, kernelsz, kernelsz]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfin's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name is 'convt2d':
                # [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfin's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_in, ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[1])))

            elif name is 'linear':
                # [ch_out, ch_in]
                w = nn.Parameter(torch.ones(*param))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name is 'bn':
                # [ch_out]
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

                # must set requires_grad=False
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])


            elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
                          'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
                continue
            else:
                raise NotImplementedError






    def extra_repr(self):
        info = ''

        for name, param in self.config:
            if name is 'conv2d':
                tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)'\
#                       %(param[1], param[0], param[2], param[3], param[4], param[5],)
                info += tmp + '\n'

            elif name is 'convt2d':
                tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)'\
#                       %(param[0], param[1], param[2], param[3], param[4], param[5],)
                info += tmp + '\n'

            elif name is 'linear':
                tmp = 'linear:(in:%d, out:%d)'%(param[1], param[0])
                info += tmp + '\n'

            elif name is 'leakyrelu':
                tmp = 'leakyrelu:(slope:%f)'%(param[0])
                info += tmp + '\n'


            elif name is 'avg_pool2d':
                tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2])
                info += tmp + '\n'
            elif name is 'max_pool2d':
                tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)'%(param[0], param[1], param[2])
                info += tmp + '\n'
            elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']:
                tmp = name + ':' + str(tuple(param))
                info += tmp + '\n'
            else:
                raise NotImplementedError

        return info



    def forward(self, x, vars=None, bn_training=True):
        """
        This function can be called by finetunning, however, in finetunning, we dont wish to update
        running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights.
        Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False
        but weight/bias will be updated and not dirty initial theta parameters via fast_weiths.
        :param x: [b, 1, 28, 28]
        :param vars:
        :param bn_training: set False to not update
        :return: x, loss, likelihood, kld
        """

        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        for name, param in self.config:
            if name is 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep synchrozied of forward_encoder and forward_decoder!
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
                # print(name, param, '\tout:', x.shape)
            elif name is 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep synchrozied of forward_encoder and forward_decoder!
                x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
                # print(name, param, '\tout:', x.shape)
            elif name is 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
                # print('forward:', idx, x.norm().item())
            elif name is 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1]
                x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                idx += 2
                bn_idx += 2

            elif name is 'flatten':
                # print(x.shape)
                x = x.view(x.size(0), -1)
            elif name is 'reshape':
                # [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name is 'relu':
                x = F.relu(x, inplace=param[0])
            elif name is 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name is 'tanh':
                x = F.tanh(x)
            elif name is 'sigmoid':
                x = torch.sigmoid(x)
            elif name is 'upsample':
                x = F.upsample_nearest(x, scale_factor=param[0])
            elif name is 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name is 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])

            else:
                raise NotImplementedError

        # make sure variable is used properly
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)


        return x


    def zero_grad(self, vars=None):
        """

        :param vars:
        :return:
        """
        with torch.no_grad():
            if vars is None:
                for p in self.vars:
                    if p.grad is not None:
                        p.grad.zero_()
            else:
                for p in vars:
                    if p.grad is not None:
                        p.grad.zero_()

    def parameters(self):
        """
        override this function since initial parameters will return with a generator.
        :return:
        """
        return self.vars

net_config = [
        ('conv2d', [6, 1, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [10, 6, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [12, 10, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [16, 12, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('flatten', []),
        ('linear', [10, 64]),
        ('relu', [True]),
        ('linear', [1, 10]),
        ('sigmoid', []),
        ('reshape', [])
    ]

printlines = []

model = Learner(net_config, imgc=1, imgsz=28)
criterion = nn.BCEWithLogitsLoss()
update_lr = 0.01
meta_lr = 0.00001
n_epochs = 15
n_metaepochs = 2

metaoptimizer = optim.Adam(model.parameters(), lr=meta_lr)

for metaepoch in range(n_metaepochs):

    printlines.append('===============================')
    printlines.append(f'//           Meta-Epoch {metaepoch + 1}       //')    
    printlines.append('===============================')
    print('===============================')
    print(f'//           Meta-Epoch {metaepoch + 1}       //')    
    print('===============================')

    for mi, metabatch in enumerate(metatrain_loader, 0):  #  Meta-step
        print(mi)
        printlines.append(f'{mi} updates at Meta-Level')
        print(f'{mi} updates at Meta-Level')

        running_loss = 0.0  #  At each meta-step, the loss is reset

        # No need to store initial weights

        for pi, problem_loaders in enumerate(metabatch, 0):  #  Problem in the meta-batch

            printlines.append(f'- Problem {pi + 1} -')
            print(f'- Problem {pi + 1} -')

            problem_loader = problem_loaders['train']
            problem_loader_val = problem_loaders['val']
            ref_label = None

            new_weights = model.parameters()

            for epoch in range(n_epochs):  #  Epoch in the problem training

                printlines.append(f'Epoch {epoch + 1}')
                print(f'Epoch {epoch + 1}')

                val_loss = 0.0
                val_accuracy = 0.0

                for i, data in enumerate(problem_loader, 0):  #  Step in the problem

                    inputs_raw, labels_raw = data
                    inputs = 1 - inputs_raw
                    outputs = model(inputs, new_weights)
                    if ref_label is None:
                        ref_label = labels_raw[0]   #  On a new problem (1st step) adjust label mapping
                    labels = process_labels(labels_raw, ref_label)

                    new_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, new_weights)

                    #  As the prediction is intrinsically done with the new weights, no need to actually update the model at the Learning Level

                    printlines.append(f'Epoch {epoch + 1}, step {i + 1:5d}], Loss: {loss.item()}, Accuracy: {accuracy}')
                    print(f'Epoch {epoch + 1}, step {i + 1:5d}], Loss: {loss.item()}, Accuracy: {accuracy}')

                for iv, datav in enumerate(problem_loader_val):  #  At the end of the training process in an epoch of a problem we compute a whole validation

                    inputs_rawv, labels_rawv = datav
                    inputsv = 1 - inputs_rawv
                    outputsv = model(inputsv, new_weights)
                    labelsv = process_labels(labels_rawv, ref_label)

                    lossv = criterion(outputsv, labelsv)  #  Loss in a validation batch
                    val_loss += lossv.item()
                    val_accuracy += (((1 - outputsv) < outputsv).float() == labelsv).sum()

                printlines.append(f'Epoch {epoch + 1}, VALIDATION], Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')  #  Loss and accuracy averaged for all validation batches in the problem, displayed after whole validation
                print(f'Epoch {epoch + 1}, VALIDATION], Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')  #  Loss and accuracy averaged for all validation batches in the problem, displayed after whole validation

            running_loss += lossv  #  After all epochs (all training process) in a single problem the validation loss is added

            # Again, no need to update the model to the initial weights 
        
        metastep_loss = running_loss / metabatch_size  #  The added validation losses of all problems in the metabatch are averaged

        metaoptimizer.zero_grad()  #  We perform gradient descent at the Meta-Level over the averaged validation loss
        metastep_loss.backward()
        metaoptimizer.step()

        if (mi + 1) % 1000 == 0:  #  Meta-validation performed every 1000 meta-steps

            printlines.append('META-VALIDATION STEP:')
            print('META-VALIDATION STEP:')

            for mbvi, metabatch_val in enumerate(metaval_loader):  #  Meta-validation meta-step

                if (mbvi + 1) % 10 == 0:

                    printlines.append(f'Validation step {mbvi + 1}')
                    print(f'Validation step {mbvi + 1}')

                for problem_loaders in metabatch_val:  #  Problem in the meta-validation meta-batch

                    problem_loader = problem_loaders['train']
                    problem_loader_val = problem_loaders['val']
                    ref_label = None
                    new_weights = model.parameters()

                    for epoch in range(n_epochs):  #  Epoch in the problem training

                        val_loss = 0.0
                        val_accuracy = 0.0

                        for i, data in enumerate(problem_loader, 0):  #  Step in the problem
                            
                            inputs_raw, labels_raw = data
                            inputs = 1 - inputs_raw
                            outputs = model(inputs)
                            if ref_label is None:
                                ref_label = labels_raw[0]
                            labels = process_labels(labels_raw, ref_label)

                            new_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, new_weights)

                        #    printlines.append(f'Epoch {epoch + 1}, step {i + 1:5d}], Loss: {loss.item()}, Accuracy: {accuracy}')

                        for iv, datav in enumerate(problem_loader_val):  #  At the end of the training process in an epoch of a problem we compute a whole validation, as in Meta-Train

                            inputs_rawv, labels_rawv = datav
                            inputsv = 1 - inputs_rawv
                            outputsv = model(inputsv)
                            labelsv = process_labels(labels_rawv, ref_label)
                            
                            lossv = criterion(outputsv, labelsv)
                            val_loss += lossv.item()
                            val_accuracy += (((1 - outputsv) < outputsv).float() == labelsv).sum()

                    
                    if (mbvi + 1) % 10 == 0:

                        printlines.append(f'Last epoch, VALIDATION], Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')  # The Meta-Validation only runs for informative matters, so our goal is to have this at the end of each problem (every 10 steps)
                        print(f'Last epoch, VALIDATION], Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')  # The Meta-Validation only runs for informative matters, so our goal is to have this at the end of each problem (every 10 steps)

            printlines.append('END OF META-VALIDATION STEP')
            print('END OF META-VALIDATION STEP')

Metabloggism avatar Feb 09 '23 10:02 Metabloggism