HowToTrainYourMAMLPytorch
line 244 in meta_neural_network_architectures.py
Thank you for releasing the code.
I notice that the function
def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False)
has a training indicator. However, within the function (line 244):
output = F.batch_norm(input, running_mean, running_var, weight, bias,
                      training=True, momentum=momentum, eps=self.eps)
Should training always be set to True here? Does this affect the results reported in the paper, given that per-step batch norm appears to be an important trick for improving MAML in the paper?
Many thanks.
This is the same as issue #3: https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch/issues/3
My understanding is that this will affect the results reported in the paper. The code as written will always use the batch statistics, not a running average accumulated per step.
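For concreteness, here is a minimal sketch (hypothetical, not the repo's actual MetaBatchNormLayer) of what a per-step batch norm layer that honours the training flag could look like, with one set of running statistics per inner-loop step:

import torch
import torch.nn.functional as F

# Hypothetical sketch, not the code from meta_neural_network_architectures.py:
# per-step batch norm keeping one running mean/var buffer per inner-loop step,
# and passing the training flag through instead of hard-coding training=True.
class PerStepBatchNorm2d(torch.nn.Module):
    def __init__(self, num_features, num_steps, momentum=0.1, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(num_features))
        self.bias = torch.nn.Parameter(torch.zeros(num_features))
        # One set of running statistics per inner-loop step.
        self.register_buffer("running_mean", torch.zeros(num_steps, num_features))
        self.register_buffer("running_var", torch.ones(num_steps, num_features))

    def forward(self, x, num_step, training=True):
        # training=True: normalize with the current batch statistics and update
        # the step-specific running buffers in place.
        # training=False: normalize with the running statistics accumulated for
        # this step.
        return F.batch_norm(x, self.running_mean[num_step], self.running_var[num_step],
                            weight=self.weight, bias=self.bias,
                            training=training, momentum=self.momentum, eps=self.eps)

At meta-test time one would call this with training=False, so that the statistics accumulated for that step are used instead of the test batch's own statistics.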
What you stated was what I thought was the case. However, after doing a few tests I found that what I stated previously was the right way to go. Check for yourself.
Thanks for the quick response! The following is a short script that demonstrates my assertion. If you have tests that show otherwise, it would be great to see them.
import torch
import torch.nn.functional as F
N = 64 # batch size
C = 16 # number of channels
H = 32 # image height
W = 32 # image width
eps = 1e-05
input = 10 * torch.randn(N, C, H, W) # create a random input
running_mean = torch.zeros(C) # set the running mean for all channels to be 0
running_var = torch.ones(C) # set the running var for all channels to be 1
# Call batch norm with training=False. Expect that the input is normalized with the running mean and running variance
output = F.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=eps)
# Assert that the output is equal to the input
assert torch.allclose(input, output)
# Call batch norm with training=True. Expect that the input is normalized with batch statistics of the input.
output_bn = F.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=True, momentum=0.1, eps=eps)
# Normalize the input manually
batch_mean = torch.mean(input, dim=(0, 2, 3), keepdim=True)
# Note: batch norm normalizes with the biased variance, so use unbiased=False here
batch_var = torch.var(input, dim=(0, 2, 3), keepdim=True, unbiased=False)
output_manual = (input - batch_mean) / torch.sqrt(batch_var + eps)
# Assert that output_bn equals output_manual
assert torch.allclose(output_bn, output_manual)
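As a follow-up, here is a sketch of an extra check (it continues directly from the script above, reusing input, running_mean and running_var): the training=True call does update the running buffers in place, so a running average is accumulated; it is just not used for normalization in that mode.

# Continuing the script above: the training=True call also performed a momentum
# update of the running buffers in place (PyTorch uses the unbiased batch
# variance for the running estimate). Tolerances are loose because the batch
# norm kernels may accumulate differently from torch.mean / torch.var.
expected_mean = 0.9 * 0.0 + 0.1 * torch.mean(input, dim=(0, 2, 3))
expected_var = 0.9 * 1.0 + 0.1 * torch.var(input, dim=(0, 2, 3), unbiased=True)
assert torch.allclose(running_mean, expected_mean, rtol=1e-3, atol=1e-3)
assert torch.allclose(running_var, expected_var, rtol=1e-3, atol=1e-3)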
I can definitely confirm that this was the case back in 2018. I'll need to reconfirm with the latest versions of PyTorch. I will come back to you soon.