Allow applying on all modules, not just immediate children

Open hub-bla opened this issue 10 months ago • 4 comments

I've implemented nested module selection modeled on the CSS child combinator. By using '>' we can now select nested modules.

Example:

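import argparse
from collections import OrderedDict
from torch import nn

from corenet.modeling.misc.common import freeze_modules_based_on_opts

# The '>' selector below relies on the change proposed in this issue.
opts = argparse.Namespace(**{"model.freeze_modules": "model1>ins_model"})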

inside_model2 = nn.Sequential(
     OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ])
)

inside_model = nn.Sequential(
     OrderedDict([
          ('ins_model', inside_model2),
          ('conv1', nn.Conv2d(1,20,5))
        ])
)

model = nn.Sequential(
     OrderedDict([
          ('model1', inside_model),
          ('conv1', nn.Conv2d(20,64,5)),
          ('conv2', nn.Conv2d(20,64,5))
        ])
)

print(freeze_modules_based_on_opts(opts, model))

returns: example_result

hub-bla avatar Apr 26 '24 17:04 hub-bla

Hi @hub-bla. Thank you for your contribution. Since freeze_modules accepts regex, I'm wondering whether a more flexible regex could select nested modules with the existing code. For example, the following regex seems to work for model1>conv1:

import argparse
from collections import OrderedDict
from torch import nn

from corenet.modeling.misc.common import freeze_modules_based_on_opts


opts = argparse.Namespace(**{"model.freeze_modules": r"model1(.*)\.conv1"})

inside_model2 = nn.Sequential(
     OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ])
)

inside_model = nn.Sequential(
     OrderedDict([
          ('ins_model', inside_model2),
          ('conv1', nn.Conv2d(1,20,5))
        ])
)

model = nn.Sequential(
     OrderedDict([
          ('model1', inside_model),
          ('conv1', nn.Conv2d(20,64,5)),
          ('conv2', nn.Conv2d(20,64,5))
        ])
)

print(freeze_modules_based_on_opts(opts, model))
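
To see why a regex can reach nested entries, here is a minimal sketch that only looks at names. It assumes the pattern is tested against the dotted names that torch's named_parameters() reports, and that matching is prefix-based via re.match; how freeze_modules_based_on_opts actually matches is not shown in this thread, so treat both as assumptions. The helper names_matching is hypothetical and the name list is written out by hand for the model above.

```python
import re

# Dotted names as torch's `named_parameters()` would report them for the
# example model above (ReLU layers have no parameters).
PARAM_NAMES = [
    "model1.ins_model.conv1.weight",
    "model1.ins_model.conv1.bias",
    "model1.ins_model.conv2.weight",
    "model1.ins_model.conv2.bias",
    "model1.conv1.weight",
    "model1.conv1.bias",
    "conv1.weight",
    "conv1.bias",
    "conv2.weight",
    "conv2.bias",
]


def names_matching(pattern: str, names: list[str]) -> list[str]:
    """Return the names whose beginning matches `pattern`.

    Assumption: prefix matching with `re.match`; the library may differ.
    """
    regex = re.compile(pattern)
    return [name for name in names if regex.match(name)]


# Hypothetical helper applied to the regex from the comment above.
selected = names_matching(r"model1(.*)\.conv1", PARAM_NAMES)
for name in selected:
    print(name)
```

Under these assumptions the pattern picks up conv1 both directly under model1 and inside model1.ins_model, so it is broader than a strict child selector. A tighter pattern such as model1\.conv1 would restrict the selection to the direct child in this example.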

mohammad7t avatar May 02 '24 16:05 mohammad7t

Hi @mohammad7t, I agree that the existing code can support that operation through the loop over named_parameters. To be honest, I didn't check whether it works before I started implementing the enhancement. I followed a comment that is above the loop with named_children. # TODO: allow applying on all modules, not just immediate chidren? How?
What my code does in addition is reduce the number of logs. For example, if you want to freeze a deeply nested module, it won't produce a log line for every parameter that is going to be frozen. Instead, it will only report that the whole module is now frozen.

To be honest, I wonder if that loop with named_modules is even necessary, because everything could be done with a regex like the one you provided. It would also remove this issue

hub-bla avatar May 03 '24 05:05 hub-bla

I see where you are coming from! I agree that the # TODO: allow applying on all modules, not just immediate chidren? How? is confusing. I think the TODO is related to applying force_eval on nested modules: https://github.com/apple/corenet/blob/aaa14a602d22fe3020eb24096483cf2b8c8af4c0/corenet/modeling/misc/common.py#L200-L208

I wonder if that loop with named_modules is even necessary, because everything could be done with a regex like the one you provided.

That's a good question. I think the only reason we need the loop with named_modules is to apply force_eval as mentioned above.

I'm not entirely sure what the best solution is right now. Let me think a bit more and get back to you. Thinking out loud, I guess we don't need to support the ">" CSS operator, but the BFS is probably a good idea to address the TODO. What do you think?

Thanks again!

mohammad7t avatar May 03 '24 13:05 mohammad7t

Thank you for the clarification! Now I get it, and I agree with everything you said.

Regarding the ">" selector, I'm not sure it's unnecessary when applying BFS. The way it works now is that the input string is split on this symbol, and the resulting chunks, each of which may or may not be a regex, are then passed to the BFS. Without it, we would have to split the string on the dot symbol, which is a special character in regex, and this might cause some problems. I'll try to think about that too.
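
The splitting idea described above can be sketched as follows. This is not the code from the proposed change, only one possible reading of it: the selector is split on ">", and each chunk is treated as a regex that must fully match a child name at the corresponding depth. The function select_nested and the nested-dict TREE are hypothetical stand-ins; with real torch modules, named_children() would supply the (name, child) pairs.

```python
import re

# Stand-in for a module tree: each key is a child name, each value its children.
# With real torch modules, `module.named_children()` would supply these pairs.
TREE = {
    "model1": {
        "ins_model": {"conv1": {}, "relu1": {}, "conv2": {}, "relu2": {}},
        "conv1": {},
    },
    "conv1": {},
    "conv2": {},
}


def select_nested(tree: dict, selector: str) -> list[str]:
    """Resolve a '>'-separated selector to dotted paths, one level per chunk.

    Hypothetical helper: each chunk is a regex that must fully match a
    child name at that depth (assumption, not the PR's actual behavior).
    """
    chunks = [chunk.strip() for chunk in selector.split(">")]
    frontier = [("", tree)]  # (dotted path so far, children of that node)
    for chunk in chunks:
        regex = re.compile(chunk)
        next_frontier = []
        for path, children in frontier:
            for name, grandchildren in children.items():
                if regex.fullmatch(name):
                    full_path = f"{path}.{name}" if path else name
                    next_frontier.append((full_path, grandchildren))
        frontier = next_frontier  # descend one level for the next chunk
    return [path for path, _ in frontier]


print(select_nested(TREE, "model1>ins_model"))
print(select_nested(TREE, "model1>.*>conv.*"))
```

The concern about dots can be checked directly: splitting the regex model1(.*)\.conv1 on "." yields the chunks "model1(", "*)\" and "conv1", none of which preserves the intended pattern, whereas ">" has no special meaning in Python's re syntax and can serve as a separator without escaping.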

hub-bla avatar May 03 '24 19:05 hub-bla

Hi again. Quick update: I don't have a clean solution for resolving the TODO for skipping nested train/eval calls at the moment. Please feel free to re-open/update this issue in the future if there are updates.

mohammad7t avatar May 26 '24 03:05 mohammad7t