
Multi-task learning

Open ntelo007 opened this issue 5 years ago • 7 comments

Good evening. I would like to use a U-Net with a pre-trained ResNet encoder as a multi-task learning neural network with a multi-task weighted loss function. Could you please guide me through this process? I saw that you already have an implementation of U-Net to do that. Thank you in advance for your time.

ntelo007 avatar Feb 05 '20 13:02 ntelo007

When you say 'multi-task learning' are you referring to multi-class classification, multi-label classification or something entirely different?

@qubvel provides an example for multi-class classification in this notebook. As for a weighted loss function, check out the loss function script; some loss functions allow for weighted classes, while others do not. If you scroll to the very bottom of that page, you can see an example of how loss functions are combined.

To actually come up with the weighted values for each class, check this exchange as a starting point.
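One common heuristic from that exchange is inverse-frequency weighting: give each class a weight proportional to the inverse of its pixel count, so rare classes contribute more to the loss. A rough numpy sketch of the idea (not specific to this library; the function name is made up for illustration):

```python
import numpy as np

def inverse_frequency_weights(mask, n_classes):
    """Per-class weights inversely proportional to pixel frequency.

    mask: integer array of class labels (any shape).
    Returns weights normalized so they sum to n_classes.
    """
    counts = np.bincount(mask.ravel(), minlength=n_classes).astype(float)
    counts = np.maximum(counts, 1.0)  # avoid division by zero for absent classes
    weights = 1.0 / counts
    return weights * n_classes / weights.sum()

# Example: a tiny 4x4 mask dominated by background class 0
mask = np.array([[0, 0, 0, 0],
                 [0, 0, 0, 0],
                 [0, 0, 1, 1],
                 [0, 0, 1, 2]])
print(inverse_frequency_weights(mask, 3))  # rarest class 2 gets the largest weight
```

These weights can then be passed to whichever loss function in the script supports a class-weights argument.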

JordanMakesMaps avatar Feb 05 '20 14:02 JordanMakesMaps

For starters, thank you for your immediate response. I am talking about multiple tasks, e.g. semantic segmentation + instance segmentation + orientation learning. Every task will have its own loss function, and I want to build a network with separate final layers for those predictions, which will individually calculate their losses and then combine them into a united loss function that converges in a direction that improves all tasks at once. I saw the example; it's quite nice. Is it easy to swap in a different model, like a U-Net with a pretrained ResNet encoder?

ntelo007 avatar Feb 05 '20 14:02 ntelo007

Oh, that sounds like a fun project 😎

Not entirely sure how to go about that without doing more research myself. But as for implementing the U-Net with a ResNet encoder, it's as easy as:

import segmentation_models as sm
model = sm.Unet('resnet50', encoder_weights='imagenet')

And because the API is structured in the way that it is, the model is just a regular Keras model, so it should be easy to combine it with other models for those other tasks you're interested in.
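As for the "united loss" you described, it usually boils down to a weighted sum of the per-task losses. A minimal, framework-agnostic sketch of that idea (task names and weight values here are made up for illustration):

```python
def combined_loss(task_losses, task_weights):
    """Weighted sum of per-task losses for multi-task training.

    task_losses:  dict mapping task name -> scalar loss value
    task_weights: dict mapping task name -> weight
    """
    return sum(task_weights[name] * loss for name, loss in task_losses.items())

# Hypothetical per-batch losses for the three tasks mentioned above
losses = {'semantic': 0.40, 'instance': 0.90, 'orientation': 0.20}
weights = {'semantic': 1.0, 'instance': 0.5, 'orientation': 2.0}
print(combined_loss(losses, weights))  # 0.4 + 0.45 + 0.4 = 1.25
```

In Keras you can get the same effect by giving the model one output per task and passing `loss` and `loss_weights` dicts to `model.compile`, though picking good weights (or learning them) is its own research question.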

JordanMakesMaps avatar Feb 05 '20 14:02 JordanMakesMaps

Also, could you please tell me how I can see the details of the model variable you proposed? I would like to be able to read the differences when I try, e.g., resnet50 and resnet34. Thanks in advance.

ntelo007 avatar Feb 06 '20 09:02 ntelo007

ResNet is a style of architecture made up of residual blocks (modules), which have skip connections to pass information from earlier layers of the network to later ones. This was one way of reducing the vanishing/exploding gradient problem. ResNet34 vs. ResNet50 vs. ResNet152 (or any other variant) is named for the size of the architecture (or, how 'deep' it is). Deeper isn't always better, though; it really depends on the difficulty of the problem you're trying to solve and how much data you have. So trial and error is typically how you figure out which one to use.
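The "skip" in a residual block is just an element-wise addition of the block's input to its output, which lets gradients flow straight through. A toy numpy sketch of the idea (the lambda stands in for the block's conv layers):

```python
import numpy as np

def residual_block(x, transform):
    """y = f(x) + x : the transform learns a residual on top of the identity."""
    return transform(x) + x

x = np.array([1.0, 2.0, 3.0])
y = residual_block(x, lambda v: 0.1 * v)  # stand-in for the learned conv layers
print(y)  # [1.1 2.2 3.3]
```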

There are plenty of articles and research papers that go over the differences between different architectures, try starting out here and then continuing down the rabbit hole!

JordanMakesMaps avatar Feb 12 '20 14:02 JordanMakesMaps

Thank you for your reply. I know how to read those articles and have found several already. What I want to do is the following: at the end, ResNet has an activation layer that leads to a segmentation output. Instead of one, can I have several different activation layers that solve different tasks utilizing the same learned representations? Like adding two more layers?

Could you give me your advice on that?

ntelo007 avatar Feb 14 '20 19:02 ntelo007

Hi @ntelo007, have you implemented this already? I was working on the same implementation of training an MTL model. https://repository.tudelft.nl/islandora/object/uuid:21fc20a8-455d-4583-9698-4fea04516f03/datastream/OBJ2/download

rishavroy1264bitmesra avatar Jul 31 '21 12:07 rishavroy1264bitmesra