pytorch-moonshine
Question on function modify_forward
Hi, thanks for your work. I have the following two questions:
(1) Can you explain the function modify_forward in count_flops.py? I don't understand the recursion here: https://github.com/BayesWatch/pytorch-moonshine/blob/master/count_flops.py#L116
```python
def modify_forward(model):
    for child in model.children():
        if should_measure(child):
            def new_forward(m):
                def lambda_forward(x):
                    measure_layer(m, x)
                    return m.old_forward(x)
                return lambda_forward
            child.old_forward = child.forward
            child.forward = new_forward(child)
        else:
            modify_forward(child)
```
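To see what the recursion does on its own, here is a minimal sketch in plain Python. It uses a hypothetical stand-in `Module` class (not PyTorch) and a hypothetical `should_measure` that treats any module without children as a leaf; the real check in count_flops.py may differ. Instead of patching, it only collects the modules that `modify_forward` would patch: leaves are recorded, and containers such as `Sequential` are descended into, which is what the `else: modify_forward(child)` branch does.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module (only children())."""
    def __init__(self, *children):
        self._children = list(children)

    def children(self):
        # like nn.Module.children(): immediate submodules only
        return iter(self._children)


class Sequential(Module): pass
class Conv2d(Module): pass
class ReLU(Module): pass


def should_measure(m):
    # hypothetical rule: measure modules that have no children (leaves)
    return not any(True for _ in m.children())


def collect_measured(model):
    """Same traversal as modify_forward, but collects instead of patching."""
    found = []
    for child in model.children():
        if should_measure(child):
            found.append(child)                    # modify_forward patches this
        else:
            found.extend(collect_measured(child))  # recurse into the container
    return found


net = Sequential(Conv2d(), Sequential(Conv2d(), ReLU()), ReLU())
print([type(m).__name__ for m in collect_measured(net)])
# ['Conv2d', 'Conv2d', 'ReLU', 'ReLU']
```

Because `children()` yields only immediate submodules, the recursive call is what reaches layers nested inside blocks; without it, the inner `Conv2d` and `ReLU` would never be visited.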
I am confused about the purpose of the following two lines. Can you explain them a little?
```python
child.old_forward = child.forward
child.forward = new_forward(child)
```
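These two lines are monkey-patching. `child.old_forward = child.forward` saves the layer's original bound `forward` method, and `child.forward = new_forward(child)` stores a wrapper as an instance attribute, which Python's attribute lookup finds before the class's `forward`. Since `nn.Module.__call__` invokes `self.forward`, every later call to the layer runs the wrapper, which measures the input and then delegates to the saved original. The sketch below reproduces this in plain Python with a hypothetical `Module` stand-in and a hypothetical `measure_layer` that only records its arguments. It also shows why the `new_forward(m)` factory is needed: a closure written directly inside the loop would look up the loop variable when called, so every wrapper would see the last layer.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module: calling the module
    dispatches through self.forward, as nn.Module.__call__ does."""
    def __call__(self, x):
        return self.forward(x)


class Double(Module):
    def forward(self, x):
        return 2 * x


measured = []  # (layer, input) pairs seen by the hypothetical measure_layer


def measure_layer(m, x):
    # hypothetical stand-in for the FLOP/param counting in count_flops.py
    measured.append((m, x))


def new_forward(m):
    # factory: m is bound now, so each wrapper refers to its own layer
    def lambda_forward(x):
        measure_layer(m, x)
        return m.old_forward(x)  # run the original computation
    return lambda_forward


layer = Double()
layer.old_forward = layer.forward   # save the original bound method
layer.forward = new_forward(layer)  # instance attribute shadows Double.forward
print(layer(3), measured[-1][0] is layer)  # 6 True


def bind(m):
    # hypothetical helper with the same shape as new_forward
    return lambda x: m


layers = [Double(), Double()]
naive, bound = [], []
for child in layers:
    naive.append(lambda x: child)  # looks up `child` when called
    bound.append(bind(child))      # captures this iteration's child

print(naive[0](0) is layers[1], bound[0](0) is layers[0])  # True True
```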
(2) Why are no FLOPs counted for batch norm?
```python
elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']:
    delta_params = get_layer_param(layer)
```
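The issue does not state why count_flops.py appears to record only parameters for `BatchNorm2d`. One common convention, assumed here and not confirmed by the repo, is that at inference a batch norm following a convolution can be folded into that convolution's weights and bias, so it adds no separate work; unfolded, it costs roughly one multiply and one add per output element (standard dropout layers in the same branch are identity at inference). The sketch below checks the folding algebra for a single output channel and counts the unfolded cost; `fold_bn` and `bn_flops` are hypothetical helpers, not functions from the repo.

```python
import math


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold y = gamma * (w*x + b - mean) / sqrt(var + eps) + beta
    into a single affine map y = w2*x + b2 (one output channel)."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


def bn_flops(n, c, h, w):
    # unfolded inference BN with precomputed scale/shift:
    # one multiply + one add per output element
    return 2 * n * c * h * w


w, b = 0.5, 0.1                              # "conv" weight and bias
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0  # BN parameters and statistics
w2, b2 = fold_bn(w, b, gamma, beta, mean, var)

x = 2.0
unfolded = gamma * (w * x + b - mean) / math.sqrt(var + 1e-5) + beta
folded = w2 * x + b2
print(math.isclose(folded, unfolded), bn_flops(1, 64, 56, 56))  # True 401408
```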
Thanks in advance
@gngdb Hi, any hint here would be appreciated. Looking forward to your reply.