cnaps
Meta-BN Implementation
Hi, is there a Torch implementation of Meta-BN available as introduced in https://arxiv.org/pdf/2003.03284.pdf
Here is a class implementing Meta-BN, derived from the base class NormalizationLayer in https://github.com/cambridge-mlg/cnaps/blob/master/src/normalization_layers.py#L5.
Apologies, somehow the code formatting below is wacky.
import torch

# Assumes this class lives in cnaps src/normalization_layers.py, which defines NormalizationLayer.
class MetaBN(NormalizationLayer):
    """MetaBN Normalization Layer"""
    def __init__(self, num_features):
        """
        Initialize
        :param num_features: number of channels in the 2D convolutional layer
        """
        super(MetaBN, self).__init__(num_features)
        # Variables to store the context moments to use for normalizing the target.
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1), requires_grad=True)
        self.context_batch_var = torch.ones((1, num_features, 1, 1), requires_grad=True)

    def forward(self, x):
        """
        Normalize activations.
        :param x: input activations
        :return: normalized activations
        """
        if self.training:  # normalize the context and save off the moments
            batch_mean, batch_var = self._compute_batch_moments(x)
            x = self._normalize(x, batch_mean, batch_var)
            self.context_batch_mean = batch_mean
            self.context_batch_var = batch_var
        else:  # normalize the target with the saved moments
            x = self._normalize(x, self.context_batch_mean, self.context_batch_var)
        return x
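The context/target split can be checked in isolation with a small, self-contained sketch. It does not import cnaps: `_BaseNorm` is a hypothetical stand-in that assumes `NormalizationLayer` derives from `nn.BatchNorm2d` and that its `_compute_batch_moments` and `_normalize` helpers compute per-channel moments over the batch and spatial dimensions and then apply the learned scale and shift. `MetaBNSketch` is likewise a hypothetical name, and its initial moments do not set `requires_grad`.

```python
import torch
import torch.nn as nn


class _BaseNorm(nn.BatchNorm2d):
    """Hypothetical stand-in for cnaps' NormalizationLayer (assumed behaviour)."""

    def __init__(self, num_features):
        super().__init__(num_features, eps=1e-5, momentum=0.1, affine=True,
                         track_running_stats=True)

    def _compute_batch_moments(self, x):
        # Per-channel mean and (population) variance over batch, height and width.
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        return mean, var

    def _normalize(self, x, mean, var):
        # Standardize, then apply the learned per-channel scale and shift.
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight.view(1, -1, 1, 1) * x_hat + self.bias.view(1, -1, 1, 1)


class MetaBNSketch(_BaseNorm):
    def __init__(self, num_features):
        super().__init__(num_features)
        self.context_batch_mean = torch.zeros((1, num_features, 1, 1))
        self.context_batch_var = torch.ones((1, num_features, 1, 1))

    def forward(self, x):
        if self.training:
            # Context set: normalize with its own moments and remember them.
            mean, var = self._compute_batch_moments(x)
            self.context_batch_mean, self.context_batch_var = mean, var
            return self._normalize(x, mean, var)
        # Target set: reuse the stored context moments.
        return self._normalize(x, self.context_batch_mean, self.context_batch_var)


torch.manual_seed(0)
layer = MetaBNSketch(3)
context = torch.randn(8, 3, 4, 4) * 2.0 + 5.0
target = torch.randn(2, 3, 4, 4)

layer.train()
out_context = layer(context)  # normalized with context statistics
layer.eval()
out_target = layer(target)    # also normalized with context statistics

# Manual check: weight is 1 and bias is 0 at initialization, so the target
# output should equal plain standardization with the context moments.
ctx_mean = context.mean(dim=(0, 2, 3), keepdim=True)
ctx_var = ((context - ctx_mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
expected_target = (target - ctx_mean) / torch.sqrt(ctx_var + 1e-5)
```

In training mode the layer normalizes the context set with its own statistics; after switching to eval mode, the target batch is normalized with the stored context statistics rather than its own.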
Thanks! I ran into an issue when using this together with copy.deepcopy(model), which is common in MAML implementations when fine-tuning.
It leads to the following error:
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
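A likely cause: after a training-mode forward pass, `context_batch_mean` and `context_batch_var` are the outputs of operations on activations that require gradients, so they are non-leaf tensors stored as plain attributes, and `torch.Tensor.__deepcopy__` refuses to copy non-leaf tensors. The minimal reproduction below uses a hypothetical `Holder` module rather than the full network.

```python
import copy

import torch
import torch.nn as nn


class Holder(nn.Module):
    """Hypothetical module that caches a statistic as a plain tensor attribute."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))
        self.cached_mean = torch.zeros(3)  # a leaf tensor, so deepcopy is fine

    def forward(self, x):
        y = x * self.weight
        # y depends on self.weight, so this mean is a non-leaf tensor.
        self.cached_mean = y.mean(dim=0)
        return y


m = Holder()
fresh_copy = copy.deepcopy(m)  # works: cached_mean is still a leaf

m(torch.randn(4, 3))  # cached_mean now carries autograd history
try:
    copy.deepcopy(m)
    copied = True
except RuntimeError as err:
    copied = False
    print("deepcopy failed:", err)
```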
I can get around this by deleting the batch moment attributes prior to a deepcopy, e.g.:
for m in model.modules():
    if isinstance(m, MetaBN):
        delattr(m, 'context_batch_mean')
        delattr(m, 'context_batch_var')
However, this makes execution slower. Also, do you see any issues with this approach?
If you delete them before the copy, you need to add them back afterwards, because these attributes hold the state needed to normalize the target set with the statistics of the context set.
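Rather than deleting the attributes and re-creating them by hand, one option (a sketch, not something taken from cnaps) is to temporarily swap each stored moment for a detached clone, which is a leaf tensor and therefore deep-copyable, then put the originals back on the source model. `deepcopy_keeping_moments` and `TinyMetaBN` are hypothetical names; the helper only assumes the attribute names `context_batch_mean` and `context_batch_var` used in the MetaBN class.

```python
import copy

import torch
import torch.nn as nn

# Attribute names used by the MetaBN class in this thread.
MOMENT_ATTRS = ("context_batch_mean", "context_batch_var")


def deepcopy_keeping_moments(model):
    """Hypothetical helper: deepcopy a model whose layers hold non-leaf moments."""
    saved = []
    for module in model.modules():
        for name in MOMENT_ATTRS:
            value = getattr(module, name, None)
            if isinstance(value, torch.Tensor):
                saved.append((module, name, value))
                # A detached clone is a leaf tensor, so deepcopy accepts it.
                setattr(module, name, value.detach().clone())
    try:
        return copy.deepcopy(model)
    finally:
        # Restore the originals so the source model's autograd graph is untouched.
        for module, name, value in saved:
            setattr(module, name, value)


class TinyMetaBN(nn.Module):
    """Hypothetical stand-in that stores context moments the way MetaBN does."""

    def __init__(self, num_features):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.context_batch_mean = torch.zeros(1, num_features, 1, 1)
        self.context_batch_var = torch.ones(1, num_features, 1, 1)

    def forward(self, x):
        x = x * self.scale
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        self.context_batch_mean = mean
        self.context_batch_var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        return x


model = nn.Sequential(TinyMetaBN(3))
model(torch.randn(4, 3, 2, 2))  # the stored moments are now non-leaf tensors
clone = deepcopy_keeping_moments(model)
```

The copy keeps the current context statistics, but without their gradient history. That should be harmless if the copy runs a training-mode pass on a context set (which overwrites the moments) before normalizing any target data, but it is worth confirming against your own MAML loop.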