Meta-BN Implementation

Open dvtailor opened this issue 1 year ago • 3 comments

Hi, is a Torch implementation available for Meta-BN (introduced in https://arxiv.org/pdf/2003.03284.pdf)?

dvtailor, Jan 22 '24 14:01

Here is a class that implements Meta-BN. It derives from the base class NormalizationLayer in https://github.com/cambridge-mlg/cnaps/blob/master/src/normalization_layers.py#L5.

Apologies, somehow the code formatting below is wacky.

import torch
# NormalizationLayer is the base class defined in the file linked above.
from normalization_layers import NormalizationLayer


class MetaBN(NormalizationLayer):
    """MetaBN Normalization Layer"""
    def __init__(self, num_features):
        """
        Initialize
        :param num_features: number of channels in the 2D convolutional layer
        """
        super(MetaBN, self).__init__(num_features)
        # Variables to store the context moments to use for normalizing the target.
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1), requires_grad=True)
        self.context_batch_var = torch.ones((1, num_features, 1, 1), requires_grad=True)

    def forward(self, x):
        """
        Normalize activations.
        :param x: input activations
        :return: normalized activations
        """
        if self.training:  # normalize the context and save off the moments
            batch_mean, batch_var = self._compute_batch_moments(x)
            x = self._normalize(x, batch_mean, batch_var)
            self.context_batch_mean = batch_mean
            self.context_batch_var = batch_var
        else:  # normalize the target with the saved moments
            x = self._normalize(x, self.context_batch_mean, self.context_batch_var)

        return x

jfb54, Jan 24 '24 10:01
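
Going by the comments in forward, the layer is meant to see the context set in training mode and the target set in eval mode. The sketch below illustrates that call pattern. It is not taken from the cnaps repository: the NormalizationLayer defined here is a hypothetical stand-in that assumes the base class derives from nn.BatchNorm2d and provides _compute_batch_moments and _normalize, and the tensors are random placeholder data.

```python
import torch
import torch.nn as nn


class NormalizationLayer(nn.BatchNorm2d):
    """Hypothetical stand-in for the cnaps base class (assumed interface)."""

    def _normalize(self, x, mean, var):
        # Standardize, then apply the learned per-channel scale and shift.
        weight = self.weight.view(1, -1, 1, 1)
        bias = self.bias.view(1, -1, 1, 1)
        return (x - mean) / torch.sqrt(var + self.eps) * weight + bias

    @staticmethod
    def _compute_batch_moments(x):
        # Per-channel moments over the batch and spatial dimensions.
        mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
        var = torch.var(x, dim=(0, 2, 3), keepdim=True)
        return mean, var


class MetaBN(NormalizationLayer):
    """Same logic as the class posted in this thread."""

    def __init__(self, num_features):
        super().__init__(num_features)
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1), requires_grad=True)
        self.context_batch_var = torch.ones((1, num_features, 1, 1), requires_grad=True)

    def forward(self, x):
        if self.training:  # context pass: normalize and save the moments
            batch_mean, batch_var = self._compute_batch_moments(x)
            x = self._normalize(x, batch_mean, batch_var)
            self.context_batch_mean = batch_mean
            self.context_batch_var = batch_var
        else:  # target pass: reuse the saved context moments
            x = self._normalize(x, self.context_batch_mean, self.context_batch_var)
        return x


torch.manual_seed(0)
layer = MetaBN(num_features=3)
context = torch.randn(8, 3, 4, 4) + 5.0  # placeholder context set
target = torch.randn(2, 3, 4, 4) + 5.0   # placeholder target set

layer.train()                # context pass stores the moments
context_out = layer(context)
layer.eval()                 # target pass reads them back
target_out = layer(target)
```

The target batch is never used to compute statistics here, so a target pass before any context pass would fall back to the initial zero mean and unit variance.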

Thanks! I ran into an issue when using this together with copy.deepcopy(model), which is common in MAML implementations when fine-tuning. It raises the following error:

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

I can get around this by deleting the batch moment attributes prior to a deepcopy, e.g.:

for m in model.modules():
    if isinstance(m, MetaBN):
        delattr(m, 'context_batch_mean')
        delattr(m, 'context_batch_var')

However, this makes execution slower. Also, do you see any issues with this approach?

dvtailor, Jan 24 '24 15:01
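
A likely explanation for the error, sketched outside of MetaBN: after a training-mode forward pass, the stored moments are results of tensor operations rather than tensors created directly by the user, and PyTorch only supports deepcopy for graph leaves. The names weight and moment below are hypothetical and only illustrate the leaf versus non-leaf distinction.

```python
import copy

import torch

weight = torch.ones(3, requires_grad=True)  # created by the user: a graph leaf
moment = (weight * 2.0).mean()              # result of an operation: a non-leaf

leaf_copy = copy.deepcopy(weight)           # leaves can be deep-copied

try:
    copy.deepcopy(moment)                   # non-leaves cannot
    deepcopy_failed = False
except RuntimeError:
    deepcopy_failed = True

# Detaching yields a new leaf that shares the value but not the graph.
detached_copy = copy.deepcopy(moment.detach())
```

Detaching makes the copy possible, but the detached tensor no longer carries gradients back to whatever produced it, which matters if the moments are expected to take part in backpropagation.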

If you delete them before the copy, you would need to add them back afterwards, as these attributes hold crucial state for normalizing the target set with the statistics of the context set.

jfb54, Jan 28 '24 08:01
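
One way to follow that advice is to stash the moments, run the deepcopy, and then restore them on both the original and the copy. The following is a minimal sketch under stated assumptions, not code from the cnaps repository: deepcopy_keeping_moments and StandInMetaBN are hypothetical names, the stand-in layer only mimics how MetaBN stores its moments, and the copy receives detached moments, so it does not backpropagate into the original model's graph.

```python
import copy

import torch
import torch.nn as nn

MOMENT_ATTRS = ("context_batch_mean", "context_batch_var")


class StandInMetaBN(nn.Module):
    """Hypothetical stand-in that stores context moments the way MetaBN does."""

    def __init__(self, num_features):
        super().__init__()
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1))
        self.context_batch_var = torch.ones((1, num_features, 1, 1))

    def forward(self, x):
        if self.training:
            self.context_batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)
            self.context_batch_var = x.var(dim=(0, 2, 3), keepdim=True)
        return (x - self.context_batch_mean) / torch.sqrt(self.context_batch_var + 1e-5)


def deepcopy_keeping_moments(model, layer_type):
    """Deep-copy `model`, then restore the context moments on both models."""
    layers = [m for m in model.modules() if isinstance(m, layer_type)]
    saved = [{name: getattr(m, name) for name in MOMENT_ATTRS} for m in layers]
    for m in layers:
        for name in MOMENT_ATTRS:
            delattr(m, name)
    try:
        clone = copy.deepcopy(model)
    finally:
        # Put the original (graph-attached) moments back on the source model.
        for m, moments in zip(layers, saved):
            for name, value in moments.items():
                setattr(m, name, value)
    # modules() visits the layers in the same order on the copy.
    clone_layers = [m for m in clone.modules() if isinstance(m, layer_type)]
    for m, moments in zip(clone_layers, saved):
        for name, value in moments.items():
            # Detached values: the copy is cut off from the source graph.
            setattr(m, name, value.detach().clone())
    return clone


model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), StandInMetaBN(4))
model.train()
_ = model(torch.randn(8, 3, 6, 6))  # context pass: moments become non-leaf tensors
clone = deepcopy_keeping_moments(model, StandInMetaBN)
```

Whether detached moments are acceptable depends on the training setup. If the copy is run on the context set again in training mode, its moments are recomputed anyway.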