
How to build an independent module that is stacked on the decoder?

Open TimoK93 opened this issue 2 years ago • 1 comments

Hello everyone,

I have a problem that can be solved in several ways, and I would like your opinion on which solution is best.

Problem Formulation

I want to evaluate a small architecture modification in multiple (potentially all) segmentation frameworks. This modification does not affect the backbone or the decoder-head directly; it is stacked on top of the decoder-head.

As an example, imagine the following architecture:

The backbone is ResNet50 and the decoder-head is Deeplabv3. The decoder-head produces class-logits that are used to calculate the final loss.

As the architecture modification, I now want to add a single FC layer that transforms the logits.
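To make the example concrete, here is a minimal sketch in plain PyTorch of what such a layer could look like. It assumes the logits have shape (N, num_classes, H, W), so the "FC layer" is written as a 1x1 convolution that applies the same linear map at every pixel. The class name `LogitTransform` and the identity initialization are illustrative assumptions, not part of mmsegmentation.

```python
import torch
import torch.nn as nn


class LogitTransform(nn.Module):
    """Hypothetical per-pixel linear layer applied to class logits."""

    def __init__(self, num_classes: int):
        super().__init__()
        # A 1x1 convolution applies the same linear map at every pixel.
        self.fc = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        # Assumption: start as the identity so the stacked layer
        # initially leaves the decoder's logits unchanged.
        with torch.no_grad():
            eye = torch.eye(num_classes).view(num_classes, num_classes, 1, 1)
            self.fc.weight.copy_(eye)
            self.fc.bias.zero_()

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # logits: (N, num_classes, H, W) -> same shape
        return self.fc(logits)


logits = torch.randn(2, 19, 64, 64)  # stand-in for the decode head output
transform = LogitTransform(num_classes=19)
out = transform(logits)
```

The identity start is only one possible choice; it keeps a pretrained decoder's predictions intact at the beginning of fine-tuning, while the default random initialization would also work when training from scratch.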

Note: Please do not ask why this should be useful. This is just a tractable example and for my research I will use more complex architecture changes.

Now the question: What is the most elegant implementation for this?

One Potential Solution

I could create a new class SpecialNewHead(BaseDecodeHead) that instantiates a decoder network as a member and uses it for prediction. It could look like this:

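import torch.nn as nn
from mmseg.models.builder import HEADS, build_head
from mmseg.models.decode_heads.decode_head import BaseDecodeHead


@HEADS.register_module()
class SpecialNewHead(BaseDecodeHead):
    """Wraps an existing decode head and transforms its class logits."""

    def __init__(self, decoder_parameter, **kwargs):
        super(SpecialNewHead, self).__init__(**kwargs)
        # Build the wrapped decode head from its config dict.
        self.decoder = build_head(decoder_parameter)
        # torch.nn has no FullyConnected; a 1x1 conv acts as a per-pixel FC layer.
        self.fc = nn.Conv2d(self.num_classes, self.num_classes, kernel_size=1)

    def forward(self, inputs):
        x = self.decoder(inputs)  # class logits, shape (N, num_classes, H, W)
        return self.fc(x)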
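For illustration, a config for such a wrapper might look like the sketch below. It assumes `SpecialNewHead` receives the wrapped head's config through `decoder_parameter` and builds it with a head builder. The inner settings are modelled on a typical DeepLabV3 `ASPPHead` config, and the wrapper's own `BaseDecodeHead` arguments are assumed to mirror the inner head. All values are illustrative and untested.

```python
# Sketch of an mmsegmentation-style config; all values are illustrative assumptions.
norm_cfg = dict(type='SyncBN', requires_grad=True)

# Config of the wrapped head, modelled on typical DeepLabV3 ASPPHead settings.
inner_head = dict(
    type='ASPPHead',
    in_channels=2048,
    in_index=3,
    channels=512,
    dilations=(1, 12, 24, 36),
    dropout_ratio=0.1,
    num_classes=19,
    norm_cfg=norm_cfg,
    align_corners=False,
)

model = dict(
    type='EncoderDecoder',
    backbone=dict(type='ResNetV1c', depth=50),  # remaining backbone options omitted
    decode_head=dict(
        type='SpecialNewHead',
        decoder_parameter=inner_head,
        # BaseDecodeHead arguments for the wrapper itself
        # (assumed to mirror the inner head).
        in_channels=2048,
        in_index=3,
        channels=512,
        num_classes=19,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    ),
)
```

With this layout the loss is computed on the wrapper's output, so swapping the inner head for another decoder only requires changing `inner_head`.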

Is there a better way?

Thank you for all comments or hints! Best, Timo

TimoK93, Jul 08 '22 13:07

Hi @TimoK93, sorry for the late reply. I think the modification you want is somewhat similar to BiSeNetV1 and ICNet, so this is indeed a convenient solution.
https://github.com/open-mmlab/mmsegmentation/blob/85569442b659fa59a51a88449689b2aa1603e4ba/mmseg/models/backbones/bisenetv1.py#L274
https://github.com/open-mmlab/mmsegmentation/blob/85569442b659fa59a51a88449689b2aa1603e4ba/mmseg/models/backbones/icnet.py#L20

xiexinch, Jul 10 '22 08:07